From 014c83621f4d91752c95378750808d8c26da0fdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20D=C3=B6ring?= Date: Mon, 14 Oct 2024 08:10:12 +0200 Subject: [PATCH 1/3] Frozen Functions Implementation Added comment explaining traverse with `rw` arguments Added comment about c++ traversal Added recursion guards Simplified comment on TraverseCallback operator() Removed redundant profiling loop Fixed recursion_guard by wrapping it in namespace Fixed typing import Improved handling of nullptr variant/domains Removed warn log Added JitFlag FreezingTraverseScope Using try_cast in get_traversable_base Renamed to EnableObjectTraversal flag Added comment for traverse_py_cb_ro_impl Added rw value for traversal of trampolines Removed outdated function Excluding texture from traversal outside of frozen function Testing fix for windows compilation Fixing windows compilation bugs Added base value test for custom_type_ext traversal Added nested traversal test Exposing frozen function related flags to python Improved warning about non-drjit types Added option to specify backend if no input variable is specified Allow for changes in the last dimension of tensor shapes Marked payload and fn as used Reverted comment Formatting Removed printing of keys Fixed backend test Added method to detect recording frozen functions with the wrong backend Added comments for arguments of dr.freeze decorator Added flag test to DR_TRAVERSE_CB_RO/RW macros Removed test for `EnableObjectTraversal` flag in `DR_TRAVERSE_CB` macro Fixed tensor freezing indexing issue Removed deprecated logging code Added comment about tensor shape inference Using Wenzel's documentation, and renamed arguments Added comment regarding `frozen_function_tp_traverse` and `frozen_function_clear` Fixed typo Added warning text to documentation --- drjit/__init__.py | 254 ++- ext/drjit-core | 2 +- include/drjit/array_traverse.h | 63 +- include/drjit/autodiff.h | 7 +- include/drjit/custom.h | 1 + include/drjit/extra.h | 5 + include/drjit/python.h | 63 +- include/drjit/texture.h | 46 +- include/drjit/traversable_base.h | 254 +++ src/extra/autodiff.cpp | 13 + src/python/CMakeLists.txt | 1 + src/python/apply.cpp | 59 +- src/python/apply.h | 40 +- src/python/common.h | 7 + src/python/detail.cpp | 117 +- src/python/docstr.rst | 29 + src/python/freeze.cpp | 1551 ++++++++++++++++ src/python/freeze.h | 529 ++++++ src/python/main.cpp | 7 +- src/python/texture.h | 2 + src/python/tracker.cpp | 9 +- tests/call_ext.cpp | 45 +- tests/custom_type_ext.cpp | 103 +- tests/test_custom_type_ext.py | 94 +- tests/test_freeze.py | 2824 ++++++++++++++++++++++++++++++ tests/while_loop_ext.cpp | 7 +- 26 files changed, 6070 insertions(+), 62 deletions(-) create mode 100644 include/drjit/traversable_base.h create mode 100644 src/python/freeze.cpp create mode 100644 src/python/freeze.h create mode 100644 tests/test_freeze.py diff --git a/drjit/__init__.py b/drjit/__init__.py index 81c907d41..18d60d063 100644 --- a/drjit/__init__.py +++ b/drjit/__init__.py @@ -19,12 +19,12 @@ import sys as _sys if _sys.version_info < (3, 11): try: - from typing_extensions import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable + from typing_extensions import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable, TypeVar except ImportError: raise RuntimeError( "Dr.Jit requires the 'typing_extensions' package on Python <3.11") else: - from typing import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable + from typing import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable, TypeVar from .ast import syntax, hint from .interop import wrap @@ -2494,6 +2494,256 @@ def binary_search(start, end, pred): return start +# Represents the frozen function passed to the decorator without arguments +F = TypeVar("F") +# Represents the frozen function passed to the decorator with arguments +F2 = TypeVar("F2") + +@overload +def freeze( + f: None = None, + *, + state_fn: Optional[Callable], + limit: Optional[int] = None, + warn_after: int = 10, + backend: Optional[JitBackend] = None, +) -> Callable[[F], F]: + """ + Decorator to "freeze" functions, which improves efficiency by removing + repeated JIT tracing overheads. + + In general, Dr.Jit traces computation and then compiles and launches kernels + containing this trace (see the section on :ref:`evaluation ` for + details). While the compilation step can often be skipped via caching, the + tracing cost can still be significant especially when repeatedly evaluating + complex models, e.g., as part of an optimization loop. + + The :py:func:`@dr.freeze ` decorator adresses this problem by + altogether removing the need to trace repeatedly. For example, consider the + following decorated function: + + .. code-block:: python + + @dr.freeze + def f(x, y, z): + return ... # Complicated code involving the arguments + + Dr.Jit will trace the first call to the decorated function ``f()``, while + collecting additional information regarding the nature of the function's inputs + and regarding the CPU/GPU kernel launches representing the body of ``f()``. + + If the function is subsequently called with *compatible* arguments (more on + this below), it will immediately launch the previously made CPU/GPU kernels + without re-tracing, which can substantially improve performance. + + When :py:func:`@dr.freeze ` detects *incompatibilities* (e.g., ``x`` + having a different type compared to the previous call), it will conservatively + re-trace the body and keep track of another potential input configuration. + + Frozen functions support arbitrary :ref:`PyTrees ` as function + arguments and return values. + + The following may trigger re-tracing: + + - Changes in the **type** of an argument or :ref:`PyTree ` element. + - Changes in the **length** of a container (``list``, ``tuple``, ``dict``). + - Changes of **dictionary keys** or **field names** of dataclasses. + - Changes in the AD status (:py:`dr.grad_enabled() `) of a variable. + - Changes of (non-PyTree) **Python objects**, as detected by mismatching ``hash()``. + + The following more technical conditions also trigger re-tracing: + - A Dr.Jit variable changes from/to a **scalar** configuration (size ``1``). + - The sets of variables of the same size change. In the example above, this + would be the case if ``len(x) == len(y)`` in one call, and ``len(x) != len(y)`` + subsequently. + - When Dr.Jit variables reference external memory (e.g. mapped NumPy arrays), the + memory can be aligned or unaligned. A re-tracing step is needed when this + status changes. + + These all correspond to situations where the generated kernel code may need to + change, and the system conservatively re-traces to ensure correctness. + + Frozen functions support arguments with a different variable *width* (see + :py:func:`dr.with() `) without re-tracing, as long as the sets of + variables of the same width stay consistent. + + Some constructions are problematic and should be avoided in frozen functions. + + - The function :py:func:`dr.width() ` returns an integer literal + that may be merged into the generated code. If the frozen function is later + rerun with differently-sized arguments, the executed kernels will still + reference the old size. One exception to this rule are constructions like + `dr.arange(UInt32, dr.width(a))`, where the result only implicitly depends on + the width value. + + **Advanced features**. The :py:func:`@dr.freeze ` decorator takes + several optional parameters that are helpful in certain situations. + + - **Warning when re-tracing happens too often**: Incompatible arguments trigger + re-tracing, which can mask issues where *accidentally* incompatible arguments + keep :py:func:`@dr.freeze ` from producing the expected + performance benefits. + + In such situations, it can be helpful to warn and identify changing + parameters by name. This feature is enabled and set to ``10`` by default. + + .. code-block:: pycon + + >>> @dr.freeze(warn_after=1) + >>> def f(x): + ... return x + ... + >>> f(Int(1)) + >>> f(Float(1)) + The frozen function has been recorded 2 times, this indicates a problem + with how the frozen function is being called. For example, calling it + with changing python values such as an index. For more information about + which variables changed set the log level to ``LogLevel::Debug``. + + - **Limiting memory usage**. Storing kernels for many possible input + configuration requires device memory, which can become problematic. Set the + ``limit=`` parameter to enable a LRU cache. This is useful when calls to a + function are mostly compatible but require occasional re-tracing. + + Args: + limit (Optional[int]): An optional integer specifying the maximum number of + stored configurations. Once this limit is reached, incompatible calls + requiring re-tracing will cause the last used configuration to be dropped. + + warn_after (int): When the number of re-tracing steps exceeds this value, + Dr.Jit will generate a warning that explains which variables changed + between calls to the function. + + state_fn (Optional[Callable]): This optional callable can specify additional + state to identifies the configuration. ``state_fn`` will be called with + the same arguments as that of the decorated function. It should return a + traversable object (e.g., a list or tuple) that is conceptually treated + as if it was another input of the function. + + backend (Optional[JitBackend]): If no inputs are given when calling the + frozen function, the backend used has to be specified using this argument. + It must match the backend used for computation within the function. + """ + + +@overload +def freeze( + f: F, + *, + state_fn: Optional[Callable] = None, + limit: Optional[int] = None, + warn_after: int = 10, + backend: Optional[JitBackend] = None, +) -> F: ... + + +def freeze( + f: Optional[F] = None, + *, + state_fn: Optional[Callable] = None, + limit: Optional[int] = None, + warn_after: int = 10, + backend: Optional[JitBackend] = None, +) -> Union[F, Callable[[F2], F2]]: + limit = limit if limit is not None else -1 + backend = backend if backend is not None else JitBackend.Invalid + + def decorator(f): + """ + Internal decorator, returned in ``dr.freeze`` was used with arguments. + """ + import functools + import inspect + + def inner(closure, *args, **kwargs): + """ + This inner function is the one that gets actually frozen. It receives + any additional state such as closures or state specified with the + ``state`` lambda, and allows for traversal of it. + """ + return f(*args, **kwargs) + + class FrozenFunction: + def __init__(self, f) -> None: + closure = inspect.getclosurevars(f) + self.closure = (closure.nonlocals, closure.globals) + self.frozen = detail.FrozenFunction( + inner, + limit, + warn_after, + backend, + ) + + def __call__(self, *args, **kwargs): + _state = state_fn(*args, **kwargs) if state_fn is not None else None + return self.frozen([self.closure, _state], *args, **kwargs) + + @property + def n_recordings(self): + """ + Represents the number of times the function was recorded. This + includes occasions where it was recorded due to a dry-run failing. + It does not necessarily correspond to the number of recordings + currently cached see ``n_cached_recordings`` for that. + """ + return self.frozen.n_recordings + + @property + def n_cached_recordings(self): + """ + Represents the number of recordings currently cached of the frozen + function. If a recording fails in dry-run mode, it will not create + a new recording, but replace the recording that was attemted to be + replayed. The number of recordings can also be limited with + the ``max_cache_size`` argument. + """ + return self.frozen.n_cached_recordings + + def clear(self): + """ + Clears the recordings of the frozen function, and resets the + ``n_recordings`` counter. The reference to the function is still + kept, and the frozen function can be called again to re-trace + new recordings. + """ + return self.frozen.clear() + + def __get__(self, obj, type=None): + if obj is None: + return self + else: + return FrozenMethod(self.frozen, self.closure, obj) + + class FrozenMethod(FrozenFunction): + """ + A FrozenMethod currying the object into the __call__ method. + + If the ``freeze`` decorator is applied to a method of some class, it has + to call the internal frozen function with the ``self`` argument. To this + end we implement the ``__get__`` method of the frozen function, to + return a ``FrozenMethod``, which holds a reference to the object. + The ``__call__`` method of the ``FrozenMethod`` then supplies the object + in addition to the arguments to the internal function. + """ + def __init__(self, frozen, closure, obj) -> None: + self.obj = obj + self.frozen = frozen + self.closure = closure + + def __call__(self, *args, **kwargs): + _state = state_fn(self.obj, *args, **kwargs) if state_fn is not None else None + return self.frozen([self.closure, _state], self.obj, *args, **kwargs) + + return functools.wraps(f)(FrozenFunction(f)) + + if f is not None: + return decorator(f) + else: + return decorator + + +del F +del F2 def assert_true( cond, diff --git a/ext/drjit-core b/ext/drjit-core index e63e186ee..4b7c51d0e 160000 --- a/ext/drjit-core +++ b/ext/drjit-core @@ -1 +1 @@ -Subproject commit e63e186eefc2647e00efc2e1aeae3c66d996714e +Subproject commit 4b7c51d0e9eb3fa622e1f65f6ef1a79bae51b96e diff --git a/include/drjit/array_traverse.h b/include/drjit/array_traverse.h index 8fee8681e..6ae82a5f2 100644 --- a/include/drjit/array_traverse.h +++ b/include/drjit/array_traverse.h @@ -15,6 +15,8 @@ #pragma once +#include + #define DRJIT_STRUCT_NODEF(Name, ...) \ Name(const Name &) = default; \ Name(Name &&) = default; \ @@ -140,6 +142,18 @@ namespace detail { using det_traverse_1_cb_rw = decltype(T(nullptr)->traverse_1_cb_rw(nullptr, nullptr)); + template + using det_get = decltype(std::declval().get()); + + template + using det_const_get = decltype(std::declval().get()); + + template + using det_begin = decltype(std::declval().begin()); + + template + using det_end = decltype(std::declval().end()); + inline drjit::string get_label(const char *s, size_t i) { auto skip = [](char c) { return c == ' ' || c == '\r' || c == '\n' || c == '\t' || c == ','; @@ -180,10 +194,17 @@ template auto labels(const T &v) { } template -void traverse_1_fn_ro(const Value &value, void *payload, void (*fn)(void *, uint64_t)) { - (void) payload; (void) fn; +void traverse_1_fn_ro(const Value &value, void *payload, + void (*fn)(void *, uint64_t, const char *, + const char *)) { + DRJIT_MARK_USED(payload); + DRJIT_MARK_USED(fn); if constexpr (is_jit_v && depth_v == 1) { - fn(payload, value.index_combined()); + if constexpr(Value::IsClass) + fn(payload, value.index_combined(), Value::CallSupport::Variant, + Value::CallSupport::Domain); + else + fn(payload, value.index_combined(), "", ""); } else if constexpr (is_traversable_v) { traverse_1(fields(value), [payload, fn](auto &x) { traverse_1_fn_ro(x, payload, fn); @@ -198,14 +219,34 @@ void traverse_1_fn_ro(const Value &value, void *payload, void (*fn)(void *, uint is_detected_v) { if (value) value->traverse_1_cb_ro(payload, fn); + + } else if constexpr (is_detected_v && + is_detected_v) { + for (auto elem : value) { + traverse_1_fn_ro(elem, payload, fn); + } + } else if constexpr (is_detected_v) { + const auto *tmp = value.get(); + traverse_1_fn_ro(tmp, payload, fn); + } else if constexpr (is_detected_v) { + value.traverse_1_cb_ro(payload, fn); } } template -void traverse_1_fn_rw(Value &value, void *payload, uint64_t (*fn)(void *, uint64_t)) { - (void) payload; (void) fn; +void traverse_1_fn_rw(Value &value, void *payload, + uint64_t (*fn)(void *, uint64_t, const char *, + const char *)) { + DRJIT_MARK_USED(payload); + DRJIT_MARK_USED(fn); if constexpr (is_jit_v && depth_v == 1) { - value = Value::borrow((typename Value::Index) fn(payload, value.index_combined())); + if constexpr(Value::IsClass) + value = Value::borrow((typename Value::Index) fn( + payload, value.index_combined(), Value::CallSupport::Variant, + Value::CallSupport::Domain)); + else + value = Value::borrow((typename Value::Index) fn( + payload, value.index_combined(), "", "")); } else if constexpr (is_traversable_v) { traverse_1(fields(value), [payload, fn](auto &x) { traverse_1_fn_rw(x, payload, fn); @@ -220,6 +261,16 @@ void traverse_1_fn_rw(Value &value, void *payload, uint64_t (*fn)(void *, uint64 is_detected_v) { if (value) value->traverse_1_cb_rw(payload, fn); + } else if constexpr (is_detected_v && + is_detected_v) { + for (auto elem : value) { + traverse_1_fn_rw(elem, payload, fn); + } + } else if constexpr (is_detected_v) { + auto *tmp = value.get(); + traverse_1_fn_rw(tmp, payload, fn); + } else if constexpr (is_detected_v) { + value.traverse_1_cb_rw(payload, fn); } } diff --git a/include/drjit/autodiff.h b/include/drjit/autodiff.h index ad22d6303..6049ca42e 100644 --- a/include/drjit/autodiff.h +++ b/include/drjit/autodiff.h @@ -973,7 +973,9 @@ NAMESPACE_BEGIN(detail) /// Internal operations for traversing nested data structures and fetching or /// storing indices. Used in ``call.h`` and ``loop.h``. -template void collect_indices_fn(void *p, uint64_t index) { +template +void collect_indices_fn(void *p, uint64_t index, const char * /*variant*/, + const char * /*domain*/) { vector &indices = *(vector *) p; if constexpr (IncRef) index = ad_var_inc_ref(index); @@ -985,7 +987,8 @@ struct update_indices_payload { size_t &pos; }; -inline uint64_t update_indices_fn(void *p, uint64_t) { +inline uint64_t update_indices_fn(void *p, uint64_t, const char * /*variant*/, + const char * /*domain*/) { update_indices_payload &payload = *(update_indices_payload *) p; return payload.indices[payload.pos++]; } diff --git a/include/drjit/custom.h b/include/drjit/custom.h index 7a38ec3dd..143a2d7d8 100644 --- a/include/drjit/custom.h +++ b/include/drjit/custom.h @@ -20,6 +20,7 @@ #include #include +#include NAMESPACE_BEGIN(drjit) NAMESPACE_BEGIN(detail) diff --git a/include/drjit/extra.h b/include/drjit/extra.h index d2f24dd00..9df259d2b 100644 --- a/include/drjit/extra.h +++ b/include/drjit/extra.h @@ -250,6 +250,11 @@ extern DRJIT_EXTRA_EXPORT bool ad_release_one_output(drjit::detail::CustomOpBase extern DRJIT_EXTRA_EXPORT void ad_copy_implicit_deps(drjit::vector &, bool input); +/// Retrieve a list of ad indices, that are the target of edges, that have been +/// postponed by the current scope +extern DRJIT_EXTRA_EXPORT void ad_scope_postponed(drjit::vector *dst); + + /// Kahan-compensated floating point atomic scatter-addition extern DRJIT_EXTRA_EXPORT void ad_var_scatter_add_kahan(uint64_t *target_1, uint64_t *target_2, uint64_t value, diff --git a/include/drjit/python.h b/include/drjit/python.h index 29229bdaa..7e5d9c339 100644 --- a/include/drjit/python.h +++ b/include/drjit/python.h @@ -54,6 +54,7 @@ #include #include #include +#include NAMESPACE_BEGIN(drjit) struct ArrayBinding; @@ -1057,25 +1058,73 @@ template void bind_all(ArrayBinding &b) { // Expose already existing object tree traversal callbacks (T::traverse_1_..) in Python. // This functionality is needed to traverse custom/opaque C++ classes and correctly // update their members when they are used in vectorized loops, function calls, etc. -template auto& bind_traverse(nanobind::class_ &cls) { +template auto &bind_traverse(nanobind::class_ &cls) +{ namespace nb = nanobind; - struct Payload { nb::callable c; }; + struct Payload { + nb::callable c; + }; + + static_assert(std::is_base_of_v); cls.def("_traverse_1_cb_ro", [](const T *self, nb::callable c) { Payload payload{ std::move(c) }; - self->traverse_1_cb_ro((void *) &payload, [](void *p, uint64_t index) { - ((Payload *) p)->c(index); - }); + self->traverse_1_cb_ro((void *) &payload, + [](void *p, uint64_t index, const char *variant, const char *domain) { + ((Payload *) p)->c(index, variant, domain); + }); }); cls.def("_traverse_1_cb_rw", [](T *self, nb::callable c) { Payload payload{ std::move(c) }; - self->traverse_1_cb_rw((void *) &payload, [](void *p, uint64_t index) { - return nb::cast(((Payload *) p)->c(index)); + self->traverse_1_cb_rw((void *) &payload, [](void *p, uint64_t index, + const char *variant, + const char *domain) { + return nb::cast( + ((Payload *) p)->c(index, variant, domain)); }); }); return cls; } +inline void traverse_py_cb_ro(const TraversableBase *base, void *payload, + void (*fn)(void *, uint64_t, const char *variant, + const char *domain)) { + namespace nb = nanobind; + nb::handle self = base->self_py(); + if (!self) + return; + + auto detail = nb::module_::import_("drjit.detail"); + nb::callable traverse_py_cb_ro_fn = + nb::borrow(nb::getattr(detail, "traverse_py_cb_ro")); + + traverse_py_cb_ro_fn(self, + nb::cpp_function([&](uint64_t index, const char *variant, + const char *domain) { + fn(payload, index, variant, domain); + })); +} + +inline void traverse_py_cb_rw(TraversableBase *base, void *payload, + uint64_t (*fn)(void *, uint64_t, const char *, + const char *)) { + + namespace nb = nanobind; + nb::handle self = base->self_py(); + if (!self) + return; + + auto detail = nb::module_::import_("drjit.detail"); + nb::callable traverse_py_cb_rw_fn = + nb::borrow(nb::getattr(detail, "traverse_py_cb_rw")); + + traverse_py_cb_rw_fn(self, + nb::cpp_function([&](uint64_t index, const char *variant, + const char *domain) { + return fn(payload, index, variant, domain); + })); +} + NAMESPACE_END(drjit) diff --git a/include/drjit/texture.h b/include/drjit/texture.h index ff6b6f3a0..c67dcbdd1 100644 --- a/include/drjit/texture.h +++ b/include/drjit/texture.h @@ -18,6 +18,8 @@ #include #include #include +#include +#include #pragma once @@ -42,7 +44,7 @@ enum class CudaTextureFormat : uint32_t { Float16 = 1, /// Half precision storage format }; -template class Texture { +template class Texture : TraversableBase { public: static constexpr bool IsCUDA = is_cuda_v; static constexpr bool IsDiff = is_diff_v; @@ -1591,6 +1593,48 @@ template class Texture { mutable bool m_tensor_dirty = false; /* Flag to indicate whether public-facing unpadded tensor needs to be updated */ + +public: + void + traverse_1_cb_ro(void *payload, + drjit ::detail ::traverse_callback_ro fn) const override { + // Only traverse the texture for frozen functions, since accidentally + // traversing the scene in loops or vcalls can cause issues. + if (!jit_flag(JitFlag::EnableObjectTraversal)) + return; + + DRJIT_MAP(DR_TRAVERSE_MEMBER_RO, m_value, m_unpadded_value, + m_resolution_opaque, m_inv_resolution); + if constexpr (HasCudaTexture) { + uint32_t n_textures = 1 + ((m_channels - 1) / 4); + std::vector indices(n_textures); + jit_cuda_tex_get_indices(m_handle, indices.data()); + for (uint32_t i = 0; i < n_textures; i++) { + fn(payload, indices[i], "", ""); + } + } + } + void traverse_1_cb_rw(void *payload, + drjit ::detail ::traverse_callback_rw fn) override { + // Only traverse the texture for frozen functions, since accidentally + // traversing the scene in loops or vcalls can cause issues. + if (!jit_flag(JitFlag::EnableObjectTraversal)) + return; + + DRJIT_MAP(DR_TRAVERSE_MEMBER_RW, m_value, m_unpadded_value, + m_resolution_opaque, m_inv_resolution); + if constexpr (HasCudaTexture) { + uint32_t n_textures = 1 + ((m_channels - 1) / 4); + std::vector indices(n_textures); + jit_cuda_tex_get_indices(m_handle, indices.data()); + for (uint32_t i = 0; i < n_textures; i++) { + uint64_t new_index = fn(payload, indices[i], "", ""); + if (new_index != indices[i]) + jit_raise("A texture was changed by traversing it. This is " + "not supported!"); + } + } + } }; NAMESPACE_END(drjit) diff --git a/include/drjit/traversable_base.h b/include/drjit/traversable_base.h new file mode 100644 index 000000000..3e9563659 --- /dev/null +++ b/include/drjit/traversable_base.h @@ -0,0 +1,254 @@ +#pragma once + +#include "fwd.h" +#include +#include +#include +#include +#include +#include + +NAMESPACE_BEGIN(drjit) + +NAMESPACE_BEGIN(detail) +/** + * \brief The callback used to traverse all jit arrays of a C++ object such as a + * Mitsuba scene. + * + * \param payload: + * To wrap closures, a payload can be provided to the ``traverse_1_cb_ro`` + * function, that is passed to the callback. + * + * \param index: + * A non-owning index of the traversed jit array. + * + * \param variant: + * If a ``JitArray`` has the attribute ``IsClass`` it is referring to a + * drjit class. When such a variable is traversed, the ``variant`` and + * ``domain`` string of its ``CallSupport`` is provided to the callback + * using this argument. Otherwise the string is equal to "". + * + * \param domain: + * The domain of the ``CallSupport`` when traversing a class variable. + */ +using traverse_callback_ro = void (*)(void *payload, uint64_t index, + const char *variant, const char *domain); +/** + * \brief The callback used to traverse and modify all jit arrays of a C++ + * object such as a Mitsuba scene. + * + * \param payload: + * To wrap closures, a payload can be provided to the ``traverse_1_cb_ro`` + * function, that is passed to the callback. + * + * \param index: + * A non-owning index of the traversed jit array. + * + * \param variant: + * If a ``JitArray`` has the attribute ``IsClass`` it is referring to a + * drjit class. When such a variable is traversed, the ``variant`` and + * ``domain`` string of its ``CallSupport`` is provided to the callback + * using this argument. Otherwise the string is equal to "". + * + * \param domain: + * The domain of the ``CallSupport`` when traversing a class variable. + * + * \return + * The new index of the traversed variable. This index is borrowed, and + * should therefore be non-owning. + */ +using traverse_callback_rw = uint64_t (*)(void *payload, uint64_t index, + const char *variant, + const char *domain); + +inline void log_member_open(bool rw, const char *member) { + jit_log(LogLevel::Debug, "%s%s{", rw ? "rw " : "ro ", member); +} + +inline void log_member_close() { jit_log(LogLevel::Debug, "}"); } + +NAMESPACE_END(detail) + +/** + * \brief Interface for traversing C++ objects. + * + * This interface should be inherited by any class that can be added to the + * registry. We try to ensure this by wrapping the function ``jit_registry_put`` + * in the function ``drjit::registry_put`` that takes a ``TraversableBase`` for + * the pointer argument. + */ +struct DRJIT_EXTRA_EXPORT TraversableBase : public nanobind::intrusive_base { + /** + * \brief Traverse all jit arrays in this c++ object. For every jit + * variable, the callback should be called, with the provided payload + * pointer. + * + * \param payload: + * A pointer to a payload struct. The callback ``cb`` is called with this + * pointer. + * + * \param cb: + * A function pointer, that is called with the ``payload`` pointer, the + * index of the jit variable, and optionally the domain and variant of a + * ``Class`` variable. + */ + virtual void traverse_1_cb_ro(void *payload, + detail::traverse_callback_ro cb) const = 0; + + /** + * \brief Traverse all jit arrays in this c++ object, and assign the output of the + * callback to them. For every jit variable, the callback should be called, + * with the provided payload pointer. + * + * \param payload: + * A pointer to a payload struct. The callback ``cb`` is called with this + * pointer. + * + * \param cb: + * A function pointer, that is called with the ``payload`` pointer, the + * index of the jit variable, and optionally the domain and variant of a + * ``Class`` variable. The resulting index of calling this function + * pointer will be assigned to the traversed variable. The return value + * of the is borrowed from when overwriting assigning the traversed + * variable. + */ + virtual void traverse_1_cb_rw(void *payload, + detail::traverse_callback_rw cb) = 0; +}; + +/** + * \brief Macro for generating call to \c traverse_1_fn_ro for a class member. + * + * This is only a utility macro, for the DR_TRAVERSE_CB_RO macro. It can only be + * used in a context, where the ``payload`` and ``fn`` variables are present. + */ +#define DR_TRAVERSE_MEMBER_RO(member) \ + drjit::detail::log_member_open(false, #member); \ + drjit::traverse_1_fn_ro(member, payload, fn); \ + drjit::detail::log_member_close(); + +/** + * \brief Macro for generating call to \c traverse_1_fn_rw for a class member. + * + * This is only a utility macro, for the DR_TRAVERSE_CB_RW macro. It can only be + * used in a context, where the ``payload`` and ``fn`` variables are present. + */ +#define DR_TRAVERSE_MEMBER_RW(member) \ + drjit::detail::log_member_open(true, #member); \ + drjit::traverse_1_fn_rw(member, payload, fn); \ + drjit::detail::log_member_close(); + +/** + * \brief Macro, generating the implementation of the ``traverse_1_cb_ro`` + * method. + * + * The first argument should be the base class, from which the current class + * inherits. The other arguments should list all members of that class, which + * are supposed to be read only traversable. + */ +#define DR_TRAVERSE_CB_RO(Base, ...) \ + void traverse_1_cb_ro(void *payload, \ + drjit::detail::traverse_callback_ro fn) \ + const override { \ + static_assert( \ + std::is_base_of>::value); \ + if constexpr (!std::is_same_v) \ + Base::traverse_1_cb_ro(payload, fn); \ + DRJIT_MAP(DR_TRAVERSE_MEMBER_RO, __VA_ARGS__) \ + } + +/** + * \breif Macro, generating the implementation of the ``traverse_1_cb_rw`` + * method. + * + * The first argument should be the base class, from which the current class + * inherits. The other arguments should list all members of that class, which + * are supposed to be read and write traversable. + */ +#define DR_TRAVERSE_CB_RW(Base, ...) \ + void traverse_1_cb_rw(void *payload, \ + drjit::detail::traverse_callback_rw fn) override { \ + static_assert( \ + std::is_base_of>::value); \ + if constexpr (!std::is_same_v) \ + Base::traverse_1_cb_rw(payload, fn); \ + DRJIT_MAP(DR_TRAVERSE_MEMBER_RW, __VA_ARGS__) \ + } + +/** + * \brief Macro, generating the both the implementations of the + * ``traverse_1_cb_ro`` and ``traverse_1_cb_rw`` methods. + * + * The first argument should be the base class, from which the current class + * inherits. The other arguments should list all members of that class, which + * are supposed to be read and write traversable. + */ +#define DR_TRAVERSE_CB(Base, ...) \ +public: \ + DR_TRAVERSE_CB_RO(Base, __VA_ARGS__) \ + DR_TRAVERSE_CB_RW(Base, __VA_ARGS__) + +/** + * \brief Macro, generating the implementations of ``traverse_1_cb_ro`` and + * ``traverse_1_cb_rw`` of a nanobind trampoline class. + * + * This macro should only be instantiated on trampoline classes, that serve as + * the base class for derived types in Python. Adding this macro to a trampoline + * class, allows for the automatic traversal of all python members in any + * derived python class. + */ +#define DR_TRAMPOLINE_TRAVERSE_CB(Base) \ +public: \ + void traverse_1_cb_ro(void *payload, \ + drjit::detail::traverse_callback_ro fn) \ + const override { \ + DRJIT_MARK_USED(payload); \ + DRJIT_MARK_USED(fn); \ + if constexpr (!std ::is_same_v) \ + Base::traverse_1_cb_ro(payload, fn); \ + drjit::traverse_py_cb_ro(this, payload, fn); \ + } \ + void traverse_1_cb_rw(void *payload, \ + drjit::detail::traverse_callback_rw fn) override { \ + DRJIT_MARK_USED(payload); \ + DRJIT_MARK_USED(fn); \ + if constexpr (!std ::is_same_v) \ + Base::traverse_1_cb_rw(payload, fn); \ + drjit::traverse_py_cb_rw(this, payload, fn); \ + } + +/** + * \brief Register a \c TraversableBase pointer with Dr.Jit's pointer registry + * + * This should be used instead of \c jit_registry_put, as it enforces the + * pointers to be of type \c TraversableBase, allowing for traversal of registry + * bound pointers. + * + * Dr.Jit provides a central registry that maps registered pointer values to + * low-valued 32-bit IDs. The main application is efficient virtual function + * dispatch via \ref jit_var_call(), through the registry could be used for + * other applications as well. + * + * This function registers the specified pointer \c ptr with the registry, + * returning the associated ID value, which is guaranteed to be unique within + * the specified domain identified by the \c (variant, domain) strings. + * The domain is normally an identifier that is associated with the "flavor" + * of the pointer (e.g. instances of a particular class), and which ensures + * that the returned ID values are as low as possible. + * + * Caution: for reasons of efficiency, the \c domain parameter is assumed to a + * static constant that will remain alive. The RTTI identifier + * typeid(MyClass).name() is a reasonable choice that satisfies this + * requirement. + * + * Raises an exception when ``ptr`` is ``nullptr``, or when it has already been + * registered with *any* domain. + */ +inline uint32_t registry_put(const char *variant, const char *domain, + TraversableBase *ptr) { + return jit_registry_put(variant, domain, (void *) ptr); +} + +NAMESPACE_END(drjit) diff --git a/src/extra/autodiff.cpp b/src/extra/autodiff.cpp index 11c5fda25..9ae57f095 100644 --- a/src/extra/autodiff.cpp +++ b/src/extra/autodiff.cpp @@ -43,6 +43,7 @@ */ #include "common.h" +#include #include #include #include @@ -1783,6 +1784,18 @@ void ad_scope_leave(bool process_postponed) { } } +void ad_scope_postponed(drjit::vector *dst) { + LocalState &ls = local_state; + std::vector &scopes = ls.scopes; + if (scopes.empty()) + ad_raise("ad_scope_leave(): scope underflow!"); + Scope &scope = scopes.back(); + + for (auto &er : scope.postponed) { + dst->push_back(er.target); + } +} + /// Check if gradient tracking is enabled for the given variable int ad_grad_enabled(Index index) { ADIndex ad_index = ::ad_index(index); diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 9886323d8..5d97ee406 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -68,6 +68,7 @@ nanobind_add_module( reduce.h reduce.cpp apply.h apply.cpp eval.h eval.cpp + freeze.h freeze.cpp memop.h memop.cpp slice.h slice.cpp dlpack.h dlpack.cpp diff --git a/src/python/apply.cpp b/src/python/apply.cpp index c611e099b..ce05011dc 100644 --- a/src/python/apply.cpp +++ b/src/python/apply.cpp @@ -603,11 +603,12 @@ struct recursion_guard { ~recursion_guard() { recursion_level--; } }; -void TraverseCallback::operator()(uint64_t) { } +uint64_t TraverseCallback::operator()(uint64_t, const char *, const char *) { return 0; } void TraverseCallback::traverse_unknown(nb::handle) { } /// Invoke the given callback on leaf elements of the pytree 'h' -void traverse(const char *op, TraverseCallback &tc, nb::handle h) { +void traverse(const char *op, TraverseCallback &tc, nb::handle h, + bool rw) { nb::handle tp = h.type(); recursion_guard guard; @@ -622,30 +623,64 @@ void traverse(const char *op, TraverseCallback &tc, nb::handle h) { len = s.len(inst_ptr(h)); for (Py_ssize_t i = 0; i < len; ++i) - traverse(op, tc, nb::steal(s.item(h.ptr(), i))); + traverse(op, tc, nb::steal(s.item(h.ptr(), i)), rw); } else { tc(h); } } else if (tp.is(&PyTuple_Type)) { for (nb::handle h2 : nb::borrow(h)) - traverse(op, tc, h2); + traverse(op, tc, h2, rw); } else if (tp.is(&PyList_Type)) { for (nb::handle h2 : nb::borrow(h)) - traverse(op, tc, h2); + traverse(op, tc, h2, rw); } else if (tp.is(&PyDict_Type)) { for (nb::handle h2 : nb::borrow(h).values()) - traverse(op, tc, h2); + traverse(op, tc, h2, rw); } else { if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { for (auto [k, v] : ds) - traverse(op, tc, nb::getattr(h, k)); + traverse(op, tc, nb::getattr(h, k), rw); } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { for (nb::handle field : df) { nb::object k = field.attr(DR_STR(name)); - traverse(op, tc, nb::getattr(h, k)); + traverse(op, tc, nb::getattr(h, k), rw); } - } else if (nb::object cb = get_traverse_cb_ro(tp); cb.is_valid()) { - cb(h, nb::cpp_function([&](uint64_t index) { tc(index); })); + } else if (auto traversable = get_traversable_base(h); + traversable) { + struct Payload { + TraverseCallback &tc; + }; + Payload p{ tc }; + if (rw) { + traversable->traverse_1_cb_rw( + (void *) &p, + [](void *p, uint64_t index, const char *variant, + const char *domain) -> uint64_t { + Payload *payload = (Payload *) p; + uint64_t new_index = + payload->tc(index, variant, domain); + return new_index; + }); + } else { + traversable->traverse_1_cb_ro( + (void *) &p, + [](void *p, uint64_t index, const char *variant, + const char *domain) { + Payload *payload = (Payload *) p; + payload->tc(index, variant, domain); + }); + } + } else if (auto cb = get_traverse_cb_ro(tp); cb.is_valid() && !rw) { + cb(h, nb::cpp_function([&](uint64_t index, const char *variant, + const char *domain) { + tc(index, variant, domain); + })); + } else if (nb::object cb = get_traverse_cb_rw(tp); + cb.is_valid() && rw) { + cb(h, nb::cpp_function([&](uint64_t index, const char *variant, + const char *domain) { + return tc(index, variant, domain); + })); } else { tc.traverse_unknown(h); } @@ -903,7 +938,9 @@ nb::object transform(const char *op, TransformCallback &tc, nb::handle h) { } result = tp(**tmp); } else if (nb::object cb = get_traverse_cb_rw(tp); cb.is_valid()) { - cb(h, nb::cpp_function([&](uint64_t index) { return tc(index); })); + cb(h, + nb::cpp_function([&](uint64_t index, const char *, + const char *) { return tc(index); })); result = nb::borrow(h); } else if (!result.is_valid()) { result = tc.transform_unknown(h); diff --git a/src/python/apply.h b/src/python/apply.h index 8e67ee089..c30289640 100644 --- a/src/python/apply.h +++ b/src/python/apply.h @@ -57,7 +57,8 @@ struct TraverseCallback { // Type-erased form which is needed in some cases to traverse into opaque // C++ code. This one just gets called with Jit/AD variable indices, an // associated Python/ instance/type is not available. - virtual void operator()(uint64_t index); + virtual uint64_t operator()(uint64_t index, const char *variant = nullptr, + const char *domain = nullptr); // Traverse an unknown object virtual void traverse_unknown(nb::handle h); @@ -80,9 +81,16 @@ struct TransformCallback { /// Initialize 'h2' (already allocated) based on 'h1' virtual void operator()(nb::handle h1, nb::handle h2) = 0; - // Type-erased form which is needed in some cases to traverse into opaque - // C++ code. This one just gets called with Jit/AD variable indices, an - // associated Python/ instance/type is not available. + /** Type-erased form which is needed in some cases to traverse into opaque + * C++ code. This one just gets called with Jit/AD variable indices, an + * associated Python/ instance/type is not available. + * This can optionally return a non-owning jit_index, that will be assigned + * to the underlying variable if \c traverse is called with the \c rw + * argument set to \c true. This can be used to modify JIT variables of + * PyTrees and their C++ objects in-place. For example, when applying + * operations such as \c jit_var_schedule_force to every JIT variable in a + * PyTree. + */ virtual uint64_t operator()(uint64_t index); }; @@ -96,9 +104,27 @@ struct TransformPairCallback { virtual nb::object transform_unknown(nb::handle h1, nb::handle h2) const; }; -/// Invoke the given callback on leaf elements of the pytree 'h' -extern void traverse(const char *op, TraverseCallback &callback, - nb::handle h); +/** + * \brief Invoke the given callback on leaf elements of the pytree 'h', + * including JIT indices in c++ objects, inheriting from + * \c drjit::TraversableBase. + * + * \param op: + * Name of the operation that is performed, this will be used in the + * exceptions that might be raised during traversal. + * + * \param callback: + * The \c TraverseCallback, called for every Jit variable in the pytree. + * + * \param rw: + * Boolean, indicating if C++ objects should be traversed in read-write + * mode. If this is set to \c true, the result from the method + * \c operator()(uint64_t) of the callback will be assigned to the + * underlying variable. This does not change how Python objects are + * traversed. + */ +extern void traverse(const char *op, TraverseCallback &callback, nb::handle h, + bool rw = false); /// Parallel traversal of two compatible pytrees 'h1' and 'h2' extern void traverse_pair(const char *op, TraversePairCallback &callback, diff --git a/src/python/common.h b/src/python/common.h index 215e840ed..07fc5c854 100644 --- a/src/python/common.h +++ b/src/python/common.h @@ -88,6 +88,13 @@ inline nb::object get_dataclass_fields(nb::handle tp) { } return result; } +/// Return a pointer to the underlying C++ class if the Python object inherits +/// from TraversableBase or null otherwise +inline drjit::TraversableBase *get_traversable_base(nb::handle h) { + drjit::TraversableBase *result = nullptr; + nb::try_cast(h, result); + return result; +} /// Extract a read-only callback to traverse custom data structures inline nb::object get_traverse_cb_ro(nb::handle tp) { diff --git a/src/python/detail.cpp b/src/python/detail.cpp index ef9f27e94..9e7a90d90 100644 --- a/src/python/detail.cpp +++ b/src/python/detail.cpp @@ -111,13 +111,15 @@ void collect_indices(nb::handle h, dr::vector &indices, bool inc_ref) void operator()(nb::handle h) override { auto index_fn = supp(h.type()).index; if (index_fn) - operator()(index_fn(inst_ptr(h))); + operator()(index_fn(inst_ptr(h)), nullptr, nullptr); } - void operator()(uint64_t index) override { + uint64_t operator()(uint64_t index, const char *, + const char *) override { if (inc_ref) ad_var_inc_ref(index); result.push_back(index); + return 0; } }; @@ -281,6 +283,115 @@ bool leak_warnings() { return nb::leak_warnings() || jit_leak_warnings() || ad_leak_warnings(); } +// Have to wrap this in an unnamed namespace to prevent collisions with the +// other declaration of ``recursion_guard``. +namespace { +static int recursion_level = 0; + +// PyTrees could theoretically include cycles. Catch infinite recursion below +struct recursion_guard { + recursion_guard() { + if (++recursion_level >= 50) { + PyErr_SetString(PyExc_RecursionError, "runaway recursion detected"); + nb::raise_python_error(); + } + } + ~recursion_guard() { recursion_level--; } +}; +} // namespace + +/** + * \brief Traverses all variables of a python object. + * + * This function is used to traverse variables of python objects, inheriting + * from trampoline classes. This allows the user to freeze for example custom + * BSDFs, without having to declare its variables. + */ +void traverse_py_cb_ro_impl(nb::handle self, nb::callable c) { + recursion_guard guard; + + struct PyTraverseCallback : TraverseCallback { + void operator()(nb::handle h) override { + const ArraySupplement &s = supp(h.type()); + auto index_fn = s.index; + if (index_fn){ + if (s.is_class){ + auto variant = + nb::borrow(nb::getattr(h, "Variant")); + auto domain = + nb::borrow(nb::getattr(h, "Domain")); + operator()(index_fn(inst_ptr(h)), + variant.is_valid() ? variant.c_str() : "", + domain.is_valid() ? domain.c_str() : ""); + } + else + operator()(index_fn(inst_ptr(h)), "", ""); + } + } + uint64_t operator()(uint64_t index, const char *variant, + const char *domain) override { + m_callback(index, variant, domain); + return 0; + } + nb::callable m_callback; + + PyTraverseCallback(nb::callable c) : m_callback(c) {} + }; + + PyTraverseCallback traverse_cb(std::move(c)); + + auto dict = nb::borrow(nb::getattr(self, "__dict__")); + + for (auto value : dict.values()) { + traverse("traverse_py_cb_ro", traverse_cb, value); + } +} + +/** + * \brief Traverses all variables of a python object. + * + * This function is used to traverse variables of python objects, inheriting + * from trampoline classes. This allows the user to freeze for example custom + * BSDFs, without having to declare its variables. + */ +void traverse_py_cb_rw_impl(nb::handle self, nb::callable c) { + recursion_guard guard; + + struct PyTraverseCallback : TraverseCallback { + void operator()(nb::handle h) override { + const ArraySupplement &s = supp(h.type()); + auto index_fn = s.index; + if (index_fn){ + uint64_t new_index; + if (s.is_class) { + auto variant = + nb::borrow(nb::getattr(h, "Variant")); + auto domain = nb::borrow(nb::getattr(h, "Domain")); + new_index = operator()( + index_fn(inst_ptr(h)), + variant.is_valid() ? variant.c_str() : "", + domain.is_valid() ? domain.c_str() : ""); + } else + new_index = operator()(index_fn(inst_ptr(h)), "", ""); + s.reset_index(new_index, inst_ptr(h)); + } + } + uint64_t operator()(uint64_t index, const char *variant, const char *domain) override { + return nb::cast(m_callback(index, variant, domain)); + } + nb::callable m_callback; + + PyTraverseCallback(nb::callable c) : m_callback(c) {} + }; + + PyTraverseCallback traverse_cb(std::move(c)); + + auto dict = nb::borrow(nb::getattr(self, "__dict__")); + + for (auto value : dict.values()) { + traverse("traverse_py_cb_rw", traverse_cb, value, true); + } +} void export_detail(nb::module_ &) { nb::module_ d = nb::module_::import_("drjit.detail"); @@ -344,6 +455,8 @@ void export_detail(nb::module_ &) { d.def("leak_warnings", &leak_warnings, doc_leak_warnings); d.def("set_leak_warnings", &set_leak_warnings, doc_set_leak_warnings); + d.def("traverse_py_cb_ro", &traverse_py_cb_ro_impl); + d.def("traverse_py_cb_rw", traverse_py_cb_rw_impl); trace_func_handle = d.attr("trace_func"); } diff --git a/src/python/docstr.rst b/src/python/docstr.rst index 0046499ea..87b2e978d 100644 --- a/src/python/docstr.rst +++ b/src/python/docstr.rst @@ -6040,6 +6040,35 @@ Note that this information can also be queried in a more fine-grained manner (per variable) using the :py:attr:`drjit.ArrayBase.state` field. +.. topic:: JitFlag_KernelFreezing + + Enable recording and replay of functions annotated with :py:func:`freeze`. + + If KernelFreezing is enabled, all Dr.Jit operations executed in a function + annotated with :py:func:`freeze` are recorded during its first execution + and replayed without re-tracing on subsequent calls. + + If this flag is disabled, replay of previously frozen functions is disabled + as well. + +.. topic:: JitFlag_FreezingScope + + This flag is set to ``True`` when Dr.Jit is currently recording a frozen + function. The flag is automatically managed and should not be updated by + application code. + + User code may query this flag to conditionally optimize kernels for frozen + function recording, such as re-seeding the sampler, used for rendering. + +.. topic:: JitFlag_EnableObjectTraversal + + This flag is set to ``True`` when Dr.Jit is currently traversing + inputs and outputs of a frozen function. The flag is automatically managed + and should not be updated by application code. + + When enabled, traversal of objects such as the ``Scene`` or ``BSDFs`` in + mitsuba is enabled. + .. topic:: JitFlag_Default The default set of optimization flags consisting of diff --git a/src/python/freeze.cpp b/src/python/freeze.cpp new file mode 100644 index 000000000..421fb6ea1 --- /dev/null +++ b/src/python/freeze.cpp @@ -0,0 +1,1551 @@ +#include "freeze.h" +#include "apply.h" +#include "autodiff.h" +#include "base.h" +#include "common.h" +#include "reduce.h" +#include "listobject.h" +#include "object.h" +#include "pyerrors.h" +#include "shape.h" +#include "tupleobject.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/** + * \brief Helper struct to profile and log frozen functions. + */ +struct ProfilerPhase { + std::string m_message; + ProfilerPhase(const char *message) : m_message(message) { + jit_log(LogLevel::Debug, "profiler start: %s", message); +#if defined(DRJIT_ENABLE_NVTX) + jit_profile_range_push(message); +#endif + } + + ProfilerPhase(const drjit::TraversableBase *traversable) { + int status; + char message[1024] = {0}; + const char *name = typeid(*traversable).name(); + snprintf(message, 1024, "traverse_cb %s", name); + + jit_log(LogLevel::Debug, "profiler start: %s", message); + jit_profile_range_push(message); + m_message = message; + } + + ~ProfilerPhase() { +#if defined(DRJIT_ENABLE_NVTX) + jit_profile_range_pop(); +#endif + jit_log(LogLevel::Debug, "profiler end: %s", m_message.c_str()); + } +}; + +struct ADScopeContext { + bool process_postponed; + ADScopeContext(drjit::ADScope type, size_t size, const uint64_t *indices, + int symbolic, bool process_postponed) + : process_postponed(process_postponed) { + ad_scope_enter(type, size, indices, symbolic); + } + ~ADScopeContext() { ad_scope_leave(process_postponed); } +}; + +struct scoped_set_flag { + uint32_t backup; + scoped_set_flag(JitFlag flag, bool enabled) : backup(jit_flags()) { + uint32_t flags = backup; + if (enabled) + flags |= (uint32_t)flag; + else + flags &= ~(uint32_t) flag; + + jit_set_flags(flags); + } + + ~scoped_set_flag() { + jit_set_flags(backup); + } +}; + +using namespace detail; + +bool Layout::operator==(const Layout &rhs) const { + if (((bool) this->type != (bool) rhs.type) || !(this->type.equal(rhs.type))) + return false; + + if (this->num != rhs.num) + return false; + + if (this->fields.size() != rhs.fields.size()) + return false; + + for (uint32_t i = 0; i < this->fields.size(); ++i) { + if (!(this->fields[i].equal(rhs.fields[i]))) + return false; + } + + if (this->index != rhs.index) + return false; + + if (this->flags != rhs.flags) + return false; + + if (this->literal != rhs.literal) + return false; + + if (this->vt != rhs.vt) + return false; + + if (((bool) this->py_object != (bool) rhs.py_object) || + !this->py_object.equal(rhs.py_object)) + return false; + + return true; +} + +bool VarLayout::operator==(const VarLayout &rhs) const { + if (this->vt != rhs.vt) + return false; + + if (this->vs != rhs.vs) + return false; + + if (this->flags != rhs.flags) + return false; + + if (this->size_index != rhs.size_index) + return false; + + return true; +} + +/** + * \brief Add a variant domain pair to be traversed using the registry. + * + * When traversing a jit variable, that references a pointer to a class, + * such as a BSDF or Shape in Mitsuba, we have to traverse all objects + * registered with that variant-domain pair in the registry. This function + * adds the variant-domain pair, deduplicating the domain. Whether a + * variable references a class is represented by it's ``IsClass`` const + * attribute. If the domain is an empty string (""), this function skips + * adding the variant-domain pair. + */ +void FlatVariables::add_domain(const char *variant, const char *domain) { + // Since it is not possible to pass nullptr strings to nanobind functions we + // assume, that a valid domain indicates a valid variant. If the variant is + // empty at the end of traversal, we know that no Class variable was + // traversed, and registry traversal is not necessary. + if (domain && variant && strcmp(domain, "") != 0) { + jit_log(LogLevel::Debug, "variant=%s, domain=%s", variant, domain); + + if (domains.empty()){ + this->variant = variant; + } + else if (this->variant != variant) + jit_raise("traverse(): Variant missmatch! All arguments to a " + "frozen function have to have the same variant. " + "Variant %s of a previos argument does not match " + "variant %s of this argument.", + this->variant.c_str(), variant); + + bool contains = false; + for (std::string &d : domains) { + if (d == domain) { + contains = true; + break; + } + } + if (!contains) + domains.push_back(domain); + } +} + +/** + * Adds a jit index to the flattened array, deduplicating it. + * This allows to check for aliasing conditions, where two variables + * actually refer to the same index. The function should only be called for + * scheduled non-literal variable indices. + */ +uint32_t FlatVariables::add_jit_index(uint32_t index) { + uint32_t next_slot = this->variables.size(); + auto result = this->index_to_slot.try_emplace(index, next_slot); + auto it = result.first; + bool inserted = result.second; + + if (inserted) { + this->variables.push_back(index); + // Borrow the variable + jit_var_inc_ref(index); + this->var_layout.emplace_back(); + return next_slot; + } else { + return it.value(); + } +} + +/** + * \brief Records information about jit variables, that have been traversed. + * + * After traversing the PyTree, collecting non-literal indices in + * ``variables`` and evaluating the collected indices, we can collect + * information about the underlying variables. This information is used in + * the key of the ``RecordingMap`` to determine which recording should be + * replayed or if the function has to be re-traced. This function iterates + * over the collected indices and collects that information. + */ +void FlatVariables::record_jit_variables() { + assert(variables.size() == var_layout.size()); + for (uint32_t i = 0; i < var_layout.size(); i++){ + uint32_t index = variables[i]; + VarLayout &layout = var_layout[i]; + + VarInfo info = jit_set_backend(index); + + if (backend == info.backend || this->backend == JitBackend::None) { + backend = info.backend; + } else { + jit_raise("freeze(): backend missmatch error (backend of this " + "variable %s does not match backend of others %s)!", + info.backend == JitBackend::CUDA ? "CUDA" : "LLVM", + backend == JitBackend::CUDA ? "CUDA" : "LLVM"); + } + + if (info.type == VarType::Pointer) { + // We do not support pointers as inputs. It might be possible with + // some extra handling, but they are never used directly. + jit_raise("Pointer inputs not supported!"); + } + + layout.vs = info.state; + layout.vt = info.type; + layout.size_index = this->add_size(info.size); + + if (info.state == VarState::Evaluated) { + // Special case, handling evaluated/opaque variables. + + layout.flags |= + (info.size == 1 ? (uint32_t) LayoutFlag::SingletonArray : 0); + layout.flags |= (info.unaligned ? (uint32_t) LayoutFlag::Unaligned : 0); + + } else { + jit_raise("collect(): found variable %u in unsupported state %u!", + index, (uint32_t) info.state); + } + } +} + +/** + * This function returns an index of an equivalence class for the variable + * size in the flattened variables. + * It uses a hashmap and vector to deduplicate sizes. + * + * This is necessary, to catch cases, where two variables had the same size + * when freezing a function and two different sizes when replaying. + * In that case one kernel would be recorded, that evaluates both variables. + * However, when replaying two kernels would have to be launched since the + * now differently sized variables cannot be evaluated by the same kernel. + */ +uint32_t FlatVariables::add_size(uint32_t size) { + uint32_t next_slot = this->sizes.size(); + auto result = this->size_to_slot.try_emplace(size, next_slot); + auto it = result.first; + bool inserted = result.second; + + if (inserted) { + this->sizes.push_back(size); + return next_slot; + } else { + return it.value(); + } +} + +/** + * Traverse the variable referenced by a jit index and add it to the flat + * variables. An optional type python type can be supplied if it is known. + * Depending on the ``TraverseContext::schedule_force`` the underlying + * variable is either scheduled (``jit_var_schedule``) or force scheduled + * (``jit_var_schedule_force``). If the variable after evaluation is a + * literal, it is directly recorded in the ``layout`` otherwise, it is added + * to the ``variables`` array, allowing the variables to be used when + * recording the frozen function. + */ +void FlatVariables::traverse_jit_index(uint32_t index, TraverseContext &ctx, + nb::handle tp) { + (void) ctx; + Layout &layout = this->layout.emplace_back(); + + if (tp) + layout.type = nb::borrow(tp); + + int rv = 0; + if (ctx.schedule_force) { + // Returns owning reference + index = jit_var_schedule_force(index, &rv); + } else { + // Schedule and create owning reference + rv = jit_var_schedule(index); + jit_var_inc_ref(index); + } + + VarInfo info = jit_set_backend(index); + if (info.state == VarState::Literal) { + // Special case, where the variable is a literal. This should not + // occur, as all literals are made opaque in beforehand, however it + // is nice to have a fallback. + layout.literal = info.literal; + // Store size in index variable, as this is not used for literals + layout.index = info.size; + layout.vt = info.type; + + layout.flags |= (uint32_t) LayoutFlag::Literal; + } else { + layout.index = this->add_jit_index(index); + } + jit_var_dec_ref(index); +} + +/** + * Construct a variable, given it's layout. + * This is the counterpart to `traverse_jit_index`. + * + * Optionally, the index of a variable can be provided that will be + * overwritten with the result of this function. In that case, the function + * will check for compatible variable types. + */ +uint32_t FlatVariables::construct_jit_index(uint32_t prev_index) { + Layout &layout = this->layout[layout_index++]; + + uint32_t index; + VarType vt; + if (layout.flags & (uint32_t) LayoutFlag::Literal) { + index = jit_var_literal(this->backend, layout.vt, &layout.literal, + layout.index); + vt = layout.vt; + } else { + VarLayout &var_layout = this->var_layout[layout.index]; + index = this->variables[layout.index]; + jit_log(LogLevel::Debug, " uses output[%u] = r%u", layout.index, + index); + + jit_var_inc_ref(index); + vt = var_layout.vt; + } + + if (prev_index) { + if (vt != (VarType) jit_var_type(prev_index)) + jit_fail("VarType missmatch %u != %u while assigning (r%u) " + "-> (r%u)!", + (uint32_t) vt, (uint32_t) jit_var_type(prev_index), + (uint32_t) prev_index, (uint32_t) index); + } + return index; +} + +/** + * Add an ad variable by it's index. Both the value and gradient are added + * to the flattened variables. If the ad index has been marked as postponed + * in the \c TraverseContext.postponed field, we mark the resulting layout + * with that flag. This will cause the gradient edges to be propagated when + * assigning to the input. The function takes an optional python-type if + * it is known. + */ +void FlatVariables::traverse_ad_index(uint64_t index, TraverseContext &ctx, + nb::handle tp) { + // NOTE: instead of emplacing a Layout representing the ad variable always, + // we only do so if the gradients have been enabled. We use this format, + // since most variables will not be ad enabled. The layout therefore has to + // be peeked in ``construct_ad_index`` before deciding if an ad or jit + // index should be constructed/assigned. + int grad_enabled = ad_grad_enabled(index); + if (grad_enabled) { + Layout &layout = this->layout.emplace_back(); + uint32_t ad_index = (uint32_t) (index >> 32); + + if (tp) + layout.type = nb::borrow(tp); + layout.num = 2; + // layout.vt = jit_var_type(index); + + // Set flags + layout.flags |= (uint32_t) LayoutFlag::GradEnabled; + // If the edge with this node as it's target has been postponed by + // the isolate gradient scope, it has been enqueued and we mark the + // ad variable as such. + if (ctx.postponed && ctx.postponed->contains(ad_index)) { + layout.flags |= (uint32_t) LayoutFlag::Postponed; + } + + traverse_jit_index((uint32_t) index, ctx, tp); + uint32_t grad = ad_grad(index); + traverse_jit_index(grad, ctx, tp); + jit_var_dec_ref(grad); + } else { + traverse_jit_index(index, ctx, tp); + } +} + +/** + * Construct/assign the variable index given a layout. + * This corresponds to `traverse_ad_index`. + * + * This function is also used for assignment to ad-variables. + * If a `prev_index` is provided, and it is an ad-variable the gradient and + * value of the flat variables will be applied to the ad variable, + * preserving the ad_idnex. + * + * It returns an owning reference. + */ +uint64_t FlatVariables::construct_ad_index(uint32_t shrink, + uint64_t prev_index) { + Layout &layout = this->layout[this->layout_index]; + + uint64_t index; + if ((layout.flags & (uint32_t) LayoutFlag::GradEnabled) != 0) { + Layout &layout = this->layout[this->layout_index++]; + bool postponed = (layout.flags & (uint32_t) LayoutFlag::Postponed); + + uint32_t val = construct_jit_index(prev_index); + uint32_t grad = construct_jit_index(prev_index); + + // Resize the gradient if it is a literal + if ((VarState) jit_var_state(grad) == VarState::Literal) { + uint32_t new_grad = jit_var_resize(grad, jit_var_size(val)); + jit_var_dec_ref(grad); + grad = new_grad; + } + + // If the prev_index variable is provided we assign the new value + // and gradient to the ad variable of that index instead of creating + // a new one. + uint32_t ad_index = (uint32_t) (prev_index >> 32); + if (ad_index) { + index = (((uint64_t) ad_index) << 32) | ((uint64_t) val); + ad_var_inc_ref(index); + } else + index = ad_var_new(val); + + jit_log(LogLevel::Debug, " -> ad_var r%zu", index); + jit_var_dec_ref(val); + + // Equivalent to set_grad + ad_clear_grad(index); + ad_accum_grad(index, grad); + jit_var_dec_ref(grad); + + // Variables, that have been postponed by the isolate gradient scope + // will be enqueued, which propagates their gradeint to previous + // functions. + if (ad_index && postponed) { + ad_enqueue(drjit::ADMode::Backward, index); + } + } else { + index = construct_jit_index(prev_index); + } + + if (shrink > 0) + index = ad_var_shrink(index, shrink); + + return index; +} + +/** + * Wrapper aground traverse_ad_index for a python handle. + */ +void FlatVariables::traverse_ad_var(nb::handle h, TraverseContext &ctx) { + auto s = supp(h.type()); + + if (s.is_class) { + auto variant = nb::borrow(nb::getattr(h, "Variant")); + auto domain = nb::borrow(nb::getattr(h, "Domain")); + add_domain(variant.c_str(), domain.c_str()); + } + + raise_if(s.index == nullptr, "freeze(): ArraySupplement index function " + "pointer is nullptr."); + + uint64_t index = s.index(inst_ptr(h)); + + this->traverse_ad_index(index, ctx, h.type()); +} + +/** + * Construct an ad variable given it's layout. + * This corresponds to `traverse_ad_var` + */ +nb::object FlatVariables::construct_ad_var(const Layout &layout, + uint32_t shrink) { + uint64_t index = construct_ad_index(shrink); + + auto result = nb::inst_alloc_zero(layout.type); + const ArraySupplement &s = supp(result.type()); + s.init_index(index, inst_ptr(result)); + nb::inst_mark_ready(result); + + // We have to release the reference, since assignment will borrow from + // it. + ad_var_dec_ref(index); + + return result; +} + +/** + * Assigns an ad variable. + * Corresponds to `traverse_ad_var`. + * This uses `construct_ad_index` to either construct a new ad variable or + * assign the value and gradient to an already existing one. + */ +void FlatVariables::assign_ad_var(Layout &layout, nb::handle dst) { + const ArraySupplement &s = supp(layout.type); + + uint64_t index; + if (s.index) { + // ``construct_ad_index`` is used for assignment + index = construct_ad_index(0, s.index(inst_ptr(dst))); + } else + index = construct_ad_index(); + + s.reset_index(index, inst_ptr(dst)); + jit_log(LogLevel::Debug, "index=%zu, grad_enabled=%u, ad_grad_enabled=%u", + index, grad_enabled(dst), ad_grad_enabled(index)); + + // Release reference, since ``construct_ad_index`` returns owning + // reference and ``s.reset_index`` borrows from it. + ad_var_dec_ref(index); +} + +/** + * Traverse a c++ tree using it's `traverse_1_cb_ro` callback. + */ +void FlatVariables::traverse_cb(const drjit::TraversableBase *traversable, + TraverseContext &ctx, nb::object type) { + // ProfilerPhase profiler(traversable); + + uint32_t layout_index = this->layout.size(); + Layout &layout = this->layout.emplace_back(); + layout.type = nb::borrow(type); + + struct Payload{ + TraverseContext &ctx; + FlatVariables *flat_variables = nullptr; + uint32_t num_fields = 0; + }; + + Payload p{ctx, this, 0}; + + traversable->traverse_1_cb_ro((void*) & p, + [](void *p, uint64_t index, const char *variant, const char *domain) { + if (!index) + return; + Payload *payload = (Payload *)p; + payload->flat_variables->add_domain(variant, domain); + payload->flat_variables->traverse_ad_index(index, payload->ctx); + payload->num_fields++; + }); + + this->layout[layout_index].num = p.num_fields; +} + +/** + * Helper function, used to assign a callback variable. + * + * \param tmp + * This vector is populated with the indices to variables that have been + * constructed. It is required to release the references, since the + * references created by `construct_ad_index` are owning and they are + * borrowed after the callback returns. + */ +uint64_t FlatVariables::assign_cb_internal(uint64_t index, + index64_vector &tmp) { + if (!index) + return index; + + uint64_t new_index = this->construct_ad_index(0, index); + + tmp.push_back_steal(new_index); + return new_index; +} + +/** + * Assigns variables using it's `traverse_cb_rw` callback. + * This corresponds to `traverse_cb`. + */ +void FlatVariables::assign_cb(drjit::TraversableBase *traversable) { + Layout &layout = this->layout[layout_index++]; + + + struct Payload{ + FlatVariables *flat_variables = nullptr; + Layout &layout; + index64_vector tmp; + uint32_t field_counter = 0; + }; + Payload p{ this, layout, index64_vector(), 0 }; + traversable->traverse_1_cb_rw( + (void *) &p, [](void *p, uint64_t index, const char *, const char *) { + if (!index) + return index; + Payload *payload = (Payload *) p; + if (payload->field_counter >= payload->layout.num) + jit_raise("While traversing an object " + "for assigning inputs, the number of variables to " + "assign (>%u) did not match the number of variables " + "traversed when recording (%u)!", + payload->field_counter, payload->layout.num); + payload->field_counter++; + return payload->flat_variables->assign_cb_internal(index, payload->tmp); + }); + + if (p.field_counter != layout.num) + jit_raise("While traversing and object for assigning inputs, the " + "number of variables to assign did not match the number " + "of variables traversed when recording!"); +} + +/** + * Traverses a PyTree in DFS order, and records it's layout in the + * `layout` vector. + * + * When hitting a drjit primitive type, it calls the + * `traverse_dr_var` method, which will add their indices to the + * `flat_variables` vector. The collect method will also record metadata + * about the drjit variable in the layout. Therefore, the layout can be used + * as an identifier to the recording of the frozen function. + */ +void FlatVariables::traverse(nb::handle h, TraverseContext &ctx) { + recursion_guard guard(this); + + scoped_set_flag traverse_scope(JitFlag::EnableObjectTraversal, true); + + ProfilerPhase profiler("traverse"); + nb::handle tp = h.type(); + + auto tp_name = nb::type_name(tp).c_str(); + jit_log(LogLevel::Debug, "FlatVariables::traverse(): %s {", tp_name); + + try { + uint32_t layout_index = this->layout.size(); + Layout &layout = this->layout.emplace_back(); + layout.type = nb::borrow(tp); + if (is_drjit_type(tp)) { + const ArraySupplement &s = supp(tp); + if (s.is_tensor) { + nb::handle array = s.tensor_array(h.ptr()); + + auto full_shape = nb::borrow(shape(h)); + + // Instead of adding the whole shape of a tensor to the key, we + // only add the inner part, not containing dimension 0. When + // indexing into a tensor, this is the only dimension that is + // not used in the index calculation. When constructing a tensor + // this dimension is reconstructed from the width of the + // underlying array. + + nb::list inner_shape; + if (full_shape.size() > 0) + for (uint32_t i = 1; i < full_shape.size(); i++) { + inner_shape.append(full_shape[i]); + } + + layout.py_object = nb::tuple(inner_shape); + + traverse(nb::steal(array), ctx); + } else if (s.ndim != 1) { + Py_ssize_t len = s.shape[0]; + if (len == DRJIT_DYNAMIC) + len = s.len(inst_ptr(h)); + + layout.num = len; + + for (Py_ssize_t i = 0; i < len; ++i) + traverse(nb::steal(s.item(h.ptr(), i)), ctx); + } else { + traverse_ad_var(h, ctx); + } + } else if (tp.is(&PyTuple_Type)) { + nb::tuple tuple = nb::borrow(h); + + layout.num = tuple.size(); + + for (nb::handle h2 : tuple) { + traverse(h2, ctx); + } + } else if (tp.is(&PyList_Type)) { + nb::list list = nb::borrow(h); + + layout.num = list.size(); + + for (nb::handle h2 : list) { + traverse(h2, ctx); + } + } else if (tp.is(&PyDict_Type)) { + nb::dict dict = nb::borrow(h); + + layout.num = dict.size(); + layout.fields.reserve(layout.num); + for (auto k : dict.keys()) { + layout.fields.push_back(nb::borrow(k)); + } + + for (auto [k, v] : dict) { + traverse(v, ctx); + } + } else if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { + + layout.num = ds.size(); + layout.fields.reserve(layout.num); + for (auto k : ds.keys()) { + layout.fields.push_back(nb::borrow(k)); + } + + for (auto [k, v] : ds) { + traverse(nb::getattr(h, k), ctx); + } + } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { + + for (auto field : df) { + nb::object k = field.attr(DR_STR(name)); + layout.fields.push_back(nb::borrow(k)); + } + layout.num = layout.fields.size(); + + for (nb::handle field : df) { + nb::object k = field.attr(DR_STR(name)); + traverse(nb::getattr(h, k), ctx); + } + } else if (auto traversable = get_traversable_base(h); traversable) { + traverse_cb(traversable, ctx, nb::borrow(tp)); + } else if (auto cb = get_traverse_cb_ro(tp); cb.is_valid()) { + ProfilerPhase profiler("traverse cb"); + + uint32_t num_fields = 0; + + // Traverse the opaque C++ object + cb(h, nb::cpp_function([&](uint64_t index, const char *variant, + const char *domain) { + if (!index) + return; + add_domain(variant, domain); + num_fields++; + this->traverse_ad_index(index, ctx, nb::none()); + return; + })); + + // Update layout number of fields + this->layout[layout_index].num = num_fields; + } else { + jit_log(LogLevel::Info, + "traverse(): You passed a value of type %s to a frozen " + "function, it could not be converted to a Dr.Jit type. " + "Changing this value in future calls to the frozen " + "function will cause it to be re-traced.", + nb::str(tp).c_str()); + + layout.py_object = nb::borrow(h); + } + } catch (nb::python_error &e) { + nb::raise_from(e, PyExc_RuntimeError, + "FlatVariables::traverse(): error encountered while " + "processing an argument of type '%U' (see above).", + nb::type_name(tp).ptr()); + } catch (const std::exception &e) { + nb::chain_error(PyExc_RuntimeError, + "FlatVariables::traverse(): error encountered " + "while processing an argument of type '%U': %s", + nb::type_name(tp).ptr(), e.what()); + nb::raise_python_error(); + } + + jit_log(LogLevel::Debug, "}"); +} + +/** + * This is the counterpart to the ``traverse`` method, used to construct the + * output of a frozen function. Given a layout vector and flat_variables, it + * re-constructs the PyTree. + */ +nb::object FlatVariables::construct() { + recursion_guard guard(this); + + if (this->layout.size() == 0) { + return nb::none(); + } + + const Layout &layout = this->layout[layout_index++]; + + auto tp_name = nb::type_name(layout.type).c_str(); + jit_log(LogLevel::Debug, "FlatVariables::construct(): %s {", tp_name); + + if (layout.type.is(nb::none().type())) { + return nb::none(); + } + try { + if (is_drjit_type(layout.type)) { + const ArraySupplement &s = supp(layout.type); + if (s.is_tensor) { + nb::object array = construct(); + + // Reconstruct the full shape from the inner part, stored in the + // layout and the width of the underlying array. + auto inner_shape = nb::borrow(layout.py_object); + auto first_dim = prod(shape(array), nb::none()) + .floor_div(prod(inner_shape, nb::none())); + + nb::list full_shape; + full_shape.append(first_dim); + for (uint32_t i = 0; i < inner_shape.size(); i++) { + full_shape.append(inner_shape[i]); + } + + nb::object tensor = layout.type(array, nb::tuple(full_shape)); + return tensor; + } else if (s.ndim != 1) { + auto result = nb::inst_alloc_zero(layout.type); + dr::ArrayBase *p = inst_ptr(result); + size_t size = s.shape[0]; + if (size == DRJIT_DYNAMIC) { + size = s.len(p); + s.init(size, p); + } + for (size_t i = 0; i < size; ++i) { + result[i] = construct(); + } + nb::inst_mark_ready(result); + return result; + } else { + return construct_ad_var(layout); + } + } else if (layout.type.is(&PyTuple_Type)) { + nb::list list; + for (uint32_t i = 0; i < layout.num; ++i) { + list.append(construct()); + } + return nb::tuple(list); + } else if (layout.type.is(&PyList_Type)) { + nb::list list; + for (uint32_t i = 0; i < layout.num; ++i) { + list.append(construct()); + } + return std::move(list); + } else if (layout.type.is(&PyDict_Type)) { + nb::dict dict; + for (auto k : layout.fields) { + dict[k] = construct(); + } + return std::move(dict); + } else if (nb::dict ds = get_drjit_struct(layout.type); ds.is_valid()) { + nb::object tmp = layout.type(); + // TODO: validation against `ds` + for (auto k : layout.fields) { + nb::setattr(tmp, k, construct()); + } + return tmp; + } else if (nb::object df = get_dataclass_fields(layout.type); + df.is_valid()) { + nb::dict dict; + for (auto k : layout.fields) { + dict[k] = construct(); + } + return layout.type(**dict); + } else if (layout.py_object) { + return layout.py_object; + } else { + nb::raise("Tried to construct a variable of type %s that is not " + "constructable!", + nb::type_name(layout.type).c_str()); + } + } catch (nb::python_error &e) { + nb::raise_from(e, PyExc_RuntimeError, + "FlatVariables::construct(): error encountered while " + "processing an argument of type '%U' (see above).", + nb::type_name(layout.type).ptr()); + } catch (const std::exception &e) { + nb::chain_error(PyExc_RuntimeError, + "FlatVariables::construct(): error encountered " + "while processing an argument of type '%U': %s", + nb::type_name(layout.type).ptr(), e.what()); + nb::raise_python_error(); + } + + jit_log(LogLevel::Debug, "}"); +} + +/** + * Assigns the flattened variables to an already existing PyTree. + * This is used when input variables have changed. + */ +void FlatVariables::assign(nb::handle dst) { + recursion_guard guard(this); + scoped_set_flag traverse_scope(JitFlag::EnableObjectTraversal, true); + + nb::handle tp = dst.type(); + Layout &layout = this->layout[layout_index++]; + + jit_log(LogLevel::Debug, "FlatVariables::assign(): %s with %s {", + nb::type_name(tp).c_str(), nb::type_name(layout.type).c_str()); + + if (!layout.type.equal(tp)) + jit_fail( + "Type missmatch! Type of the object when recording (%s) does not " + "match type of object that is assigned (%s).", + nb::type_name(tp).c_str(), nb::type_name(layout.type).c_str()); + + try { + if (is_drjit_type(tp)) { + const ArraySupplement &s = supp(tp); + + if (s.is_tensor) { + nb::handle array = s.tensor_array(dst.ptr()); + assign(nb::steal(array)); + } else if (s.ndim != 1) { + Py_ssize_t len = s.shape[0]; + if (len == DRJIT_DYNAMIC) + len = s.len(inst_ptr(dst)); + + for (Py_ssize_t i = 0; i < len; ++i) + assign(dst[i]); + } else { + assign_ad_var(layout, dst); + } + } else if (tp.is(&PyTuple_Type)) { + nb::tuple tuple = nb::borrow(dst); + raise_if( + tuple.size() != layout.num, + "The number of objects in this tuple changed from %u to %u " + "while recording the function.", + layout.num, (uint32_t) tuple.size()); + + for (nb::handle h2 : tuple) + assign(h2); + } else if (tp.is(&PyList_Type)) { + nb::list list = nb::borrow(dst); + raise_if(list.size() != layout.num, + "The number of objects in this list changed from %u to %u " + "while recording the function.", + layout.num, (uint32_t) list.size()); + + for (nb::handle h2 : list) + assign(h2); + } else if (tp.is(&PyDict_Type)) { + nb::dict dict = nb::borrow(dst); + for (auto &k : layout.fields) { + if (dict.contains(&k)) + assign(dict[k]); + else + dst[k] = construct(); + } + } else if (nb::dict ds = get_drjit_struct(dst); ds.is_valid()) { + for (auto &k : layout.fields) { + if (nb::hasattr(dst, k)) + assign(nb::getattr(dst, k)); + else + nb::setattr(dst, k, construct()); + } + } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { + for (auto k : layout.fields) { + if (nb::hasattr(dst, k)) + assign(nb::getattr(dst, k)); + else + nb::setattr(dst, k, construct()); + } + } else if (auto traversable = get_traversable_base(dst); traversable) { + assign_cb(traversable); + } else if (nb::object cb = get_traverse_cb_rw(tp); cb.is_valid()) { + index64_vector tmp; + uint32_t num_fields = 0; + + cb(dst, nb::cpp_function([&](uint64_t index, const char *, + const char *) { + if (!index) + return index; + jit_log(LogLevel::Debug, + "assign(): traverse_cb[%u] was a%u r%u", num_fields, + (uint32_t) (index >> 32), (uint32_t) index); + num_fields++; + if (num_fields > layout.num) + jit_raise( + "While traversing the object of type %s " + "for assigning inputs, the number of variables " + "to assign (>%u) did not match the number of " + "variables traversed when recording(%u)!", + nb::str(tp).c_str(), num_fields, layout.num); + return assign_cb_internal(index, tmp); + })); + if (num_fields != layout.num) + jit_raise("While traversing the object of type %s " + "for assigning inputs, the number of variables " + "to assign did not match the number of variables " + "traversed when recording!", + nb::str(tp).c_str()); + } else { + } + } catch (nb::python_error &e) { + nb::raise_from(e, PyExc_RuntimeError, + "FlatVariables::assign(): error encountered while " + "processing an argument " + "of type '%U' (see above).", + nb::type_name(tp).ptr()); + } catch (const std::exception &e) { + nb::chain_error(PyExc_RuntimeError, + "FlatVariables::assign(): error encountered " + "while processing an argument " + "of type '%U': %s", + nb::type_name(tp).ptr(), e.what()); + nb::raise_python_error(); + } + + jit_log(LogLevel::Debug, "}"); +} + +/** + * First traverses the PyTree, then the registry. This ensures that + * additional data to vcalls is tracked correctly. + */ +void FlatVariables::traverse_with_registry(nb::handle h, TraverseContext &ctx) { + scoped_set_flag traverse_scope(JitFlag::EnableObjectTraversal, true); + + // Traverse the handle + traverse(h, ctx); + + // Traverse the registry (if a class variable was traversed) + if (!domains.empty()) { + ProfilerPhase profiler("traverse_registry"); + uint32_t layout_index = this->layout.size(); + Layout &layout = this->layout.emplace_back(); + layout.type = nb::borrow(nb::none()); + + uint32_t num_fields = 0; + + jit_log(LogLevel::Debug, "registry{"); + + std::vector registry_pointers; + for (std::string &domain : domains) { + uint32_t registry_bound = + jit_registry_id_bound(variant.c_str(), domain.c_str()); + uint32_t offset = registry_pointers.size(); + registry_pointers.resize(registry_pointers.size() + registry_bound, nullptr); + jit_registry_get_pointers(variant.c_str(), domain.c_str(), + ®istry_pointers[offset]); + } + + jit_log(LogLevel::Debug, "registry_bound=%u", registry_pointers.size()); + jit_log(LogLevel::Debug, "layout_index=%u", this->layout.size()); + for (void *ptr : registry_pointers) { + jit_log(LogLevel::Debug, "ptr=%p", ptr); + if (!ptr) + continue; + + // WARN: very unsafe cast! + // We assume, that any object added to the registry inherits from + // TraversableBase. This is ensured by the signature of the + // ``drjit::registry_put`` function. + auto traversable = (drjit::TraversableBase *) ptr; + auto self = traversable->self_py(); + + if (self) + traverse(self, ctx); + else + traverse_cb(traversable, ctx); + + num_fields++; + } + jit_log(LogLevel::Debug, "}"); + + this->layout[layout_index].num = num_fields; + } +} + +/** + * First assigns the registry and then the PyTree. + * Corresponds to `traverse_with_registry`. + */ +void FlatVariables::assign_with_registry(nb::handle dst) { + scoped_set_flag traverse_scope(JitFlag::EnableObjectTraversal, true); + + // Assign the handle + assign(dst); + + // Assign registry (if a class variable was traversed) + if (!domains.empty()) { + Layout &layout = this->layout[layout_index++]; + uint32_t num_fields = 0; + + jit_log(LogLevel::Debug, "registry{"); + + std::vector registry_pointers; + for (std::string &domain : domains) { + uint32_t registry_bound = + jit_registry_id_bound(variant.c_str(), domain.c_str()); + uint32_t offset = registry_pointers.size(); + registry_pointers.resize(registry_pointers.size() + registry_bound, nullptr); + jit_registry_get_pointers(variant.c_str(), domain.c_str(), + ®istry_pointers[offset]); + } + + jit_log(LogLevel::Debug, "registry_bound=%u", registry_pointers.size()); + jit_log(LogLevel::Debug, "layout_index=%u", this->layout_index); + for (void *ptr : registry_pointers) { + jit_log(LogLevel::Debug, "ptr=%p", ptr); + if (!ptr) + continue; + + // WARN: very unsafe cast! + // We assume, that any object added to the registry inherits from + // TraversableBase. This is ensured by the signature of the + // ``drjit::registry_put`` function. + auto traversable = (drjit::TraversableBase *) ptr; + auto self = traversable->self_py(); + + if (self) + assign(self); + else + assign_cb(traversable); + + num_fields++; + } + jit_log(LogLevel::Debug, "}"); + } +} + +inline void hash_combine(size_t &seed, size_t value) { + /// From CityHash (https://github.com/google/cityhash) + const size_t mult = 0x9ddfea08eb382d69ull; + size_t a = (value ^ seed) * mult; + a ^= (a >> 47); + size_t b = (seed ^ a) * mult; + b ^= (b >> 47); + seed = b * mult; +} + +size_t +FlatVariablesHasher::operator()(const std::shared_ptr &key) const { + ProfilerPhase profiler("hash"); + // Hash the layout + // NOTE: string hashing seems to be less efficient + size_t hash = key->layout.size(); + for (const Layout &layout : key->layout) { + hash_combine(hash, layout.num); + hash_combine(hash, layout.fields.size()); + hash_combine(hash, (size_t) layout.flags); + hash_combine(hash, (size_t) layout.index); + hash_combine(hash, (size_t) layout.literal); + hash_combine(hash, (size_t) layout.vt); + if (layout.type) + hash_combine(hash, nb::hash(layout.type)); + if (layout.py_object) + hash_combine(hash, nb::hash(layout.py_object)); + for (auto &field : layout.fields) { + hash_combine(hash, nb::hash(field)); + } + } + + for (const VarLayout &layout : key->var_layout) { + hash_combine(hash, (size_t) layout.vt); + hash_combine(hash, (size_t) layout.vs); + hash_combine(hash, (size_t) layout.flags); + hash_combine(hash, (size_t) layout.size_index); + } + + hash_combine(hash, (size_t) key->flags); + + return hash; +} + +/* + * Record a function, given it's python input and flattened input. + */ +nb::object FunctionRecording::record(nb::callable func, + FrozenFunction *frozen_func, + nb::list input, + const FlatVariables &in_variables) { + ProfilerPhase profiler("record"); + JitBackend backend = in_variables.backend; + + frozen_func->recording_counter++; + if (frozen_func->recording_counter > frozen_func->warn_recording_count)jit_log( + LogLevel::Warn, + "The frozen function has been recorded %u times, this indicates a " + "problem with how the frozen function is being called. For " + "example, calling it with changing python values such as a " + "index.", + frozen_func->recording_counter); + + + jit_log(LogLevel::Info, + "Recording (n_inputs=%u):", in_variables.variables.size()); + jit_freeze_start(backend, in_variables.variables.data(), + in_variables.variables.size()); + + // Record the function + nb::object output; + { + ProfilerPhase profiler("function"); + output = func(*input[0], **input[1]); + } + + // Collect nodes, that have been postponed by the `Isolate` scope in a + // hash set. + // These are the targets of postponed edges, as the isolate gradient + // scope only handles backward mode differentiation. + // If they are, then we have to enqueue them when replaying the + // recording. + tsl::robin_set postponed; + { + drjit::vector postponed_vec; + ad_scope_postponed(&postponed_vec); + for (uint32_t index : postponed_vec) + postponed.insert(index); + } + + jit_log(LogLevel::Info, "Traversing output"); + { + ProfilerPhase profiler("traverse output"); + // Enter Resume scope, so we can track gradients + ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, -1, false); + + TraverseContext ctx; + ctx.postponed = &postponed; + ctx.schedule_force = false; + out_variables.traverse(output, ctx); + ctx.schedule_force = true; + out_variables.traverse_with_registry(input, ctx); + + { // Evaluate the variables, scheduled when traversing + nb::gil_scoped_release guard; + jit_eval(); + } + + out_variables.record_jit_variables(); + } + + jit_freeze_pause(backend); + + if ((out_variables.variables.size() > 0 && + in_variables.variables.size() > 0) && + out_variables.backend != backend) { + Recording *recording = jit_freeze_stop(backend, nullptr, 0); + jit_freeze_destroy(recording); + + nb::raise( + "freeze(): backend missmatch error (backend %u of " + "output variables did not match backend %u of input variables)", + (uint32_t) out_variables.backend, (uint32_t) backend); + } + + // Exceptions, thrown by the recording functions will be recorded and + // re-thrown when calling ``jit_freeze_stop``. Since the output variables + // are borrowed, we have to release them in that case, and catch these + // exceptions. + try { + recording = jit_freeze_stop(backend, out_variables.variables.data(), + out_variables.variables.size()); + } catch (nb::python_error &e) { + out_variables.release(); + nb::raise_from(e, PyExc_RuntimeError, + "record(): error encountered while recording a function " + "(see above)."); + } catch (const std::exception &e) { + out_variables.release(); + nb::chain_error(PyExc_RuntimeError, "record(): %s", e.what()); + nb::raise_python_error(); + } + + jit_log(LogLevel::Info, "Recording done (n_outputs=%u)", + out_variables.variables.size()); + + // For catching input assignment mismatches, we assign the input and + // output + { + // Enter Resume scope, so we can track gradients + ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, -1, false); + + out_variables.layout_index = 0; + jit_log(LogLevel::Debug, "Construct:"); + output = nb::borrow(out_variables.construct()); + // NOTE: temporarily disable this to not enqueue twice + out_variables.assign(input); + out_variables.layout_index = 0; + } + + // Traversal takes owning references, so here we need to release them. + out_variables.release(); + + return output; +} +/* + * Replays the recording. + * + * This constructs the output and re-assigns the input. + */ +nb::object FunctionRecording::replay(nb::callable func, + FrozenFunction *frozen_func, + nb::list input, + const FlatVariables &in_variables) { + ProfilerPhase profiler("replay"); + + jit_log(LogLevel::Info, "Replaying:"); + int dryrun_success; + { + ProfilerPhase profiler("dry run"); + dryrun_success = + jit_freeze_dry_run(recording, in_variables.variables.data()); + } + if (!dryrun_success) { + // Dry run has failed. Re-record the function. + jit_log(LogLevel::Info, "Dry run failed! re-recording"); + this->clear(); + try { + return this->record(func, frozen_func, input, in_variables); + } catch (nb::python_error &e) { + nb::raise_from(e, PyExc_RuntimeError, + "replay(): error encountered while re-recording a " + "function (see above)."); + } catch (const std::exception &e) { + jit_freeze_abort(in_variables.backend); + + nb::chain_error(PyExc_RuntimeError, "record(): %s", e.what()); + nb::raise_python_error(); + } + } else { + ProfilerPhase profiler("jit replay"); + nb::gil_scoped_release guard; + jit_freeze_replay(recording, in_variables.variables.data(), + out_variables.variables.data()); + } + jit_log(LogLevel::Info, "Replaying done:"); + + // Construct Output variables + nb::object output; + { + // Enter Resume scope, so we can track gradients + ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, -1, false); + out_variables.layout_index = 0; + { + ProfilerPhase profiler("construct output"); + output = nb::borrow(out_variables.construct()); + } + { + ProfilerPhase profiler("assign input"); + out_variables.assign_with_registry(input); + } + } + + // out_variables is assigned by ``jit_record_replay``, which transfers + // ownership to this array. Therefore, we have to drop the variables + // afterwards. + out_variables.release(); + + return output; +} + +nb::object FrozenFunction::operator()(nb::args args, nb::kwargs kwargs) { + nb::object result; + { + // Enter Isolate grad scope, so that gradients are not propagated + // outside of the function scope. + ADScopeContext ad_scope(drjit::ADScope::Isolate, 0, nullptr, -1, true); + + // Kernel freezing can be enabled or disabled with the + // ``JitFlag::KernelFreezing``. Alternatively, when calling a frozen + // function from another one, we simply record the inner function. + if (!jit_flag(JitFlag::KernelFreezing) || + jit_flag(JitFlag::FreezingScope) || max_cache_size == 0) { + ProfilerPhase profiler("function"); + return func(*args, **kwargs); + } + + call_counter++; + + nb::list input; + input.append(args); + input.append(kwargs); + + auto in_variables = + std::make_shared(FlatVariables(in_heuristics)); + in_variables->backend = this->default_backend; + // Evaluate and traverse input variables (args and kwargs) + { + // Enter Resume scope, so we can track gradients + ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, 0, + true); + + // Traverse input variables + ProfilerPhase profiler("traverse input"); + jit_log(LogLevel::Info, "freeze(): Traversing input"); + TraverseContext ctx; + ctx.schedule_force = true; + in_variables->traverse_with_registry(input, ctx); + + { // Evaluate the variables, scheduled when traversing + nb::gil_scoped_release guard; + jit_eval(); + } + + in_variables->record_jit_variables(); + } + + in_heuristics = in_heuristics.max(in_variables->heuristic()); + + raise_if(in_variables->backend == JitBackend::None, + "freeze(): Cannot infer backend without providing input " + "variable to frozen function!"); + + auto it = this->recordings.find(in_variables); + + // Evict the least recently used recording if the cache is "full" + if (max_cache_size > 0 && + recordings.size() >= (uint32_t) max_cache_size && + it == this->recordings.end()) { + + uint32_t lru_last_used = UINT32_MAX; + RecordingMap::iterator lru_it = recordings.begin(); + + for (auto it = recordings.begin(); it != recordings.end(); it++) { + auto &recording = it.value(); + if (recording->last_used < lru_last_used) { + lru_last_used = recording->last_used; + lru_it = it; + } + } + recordings.erase(lru_it); + + it = this->recordings.find(in_variables); + } + + if (it == this->recordings.end()) { + { + // TODO: single traverse + ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, 0, + true); + in_variables->assign_with_registry(input); + } + + // FunctionRecording recording; + auto recording = std::make_unique(); + recording->last_used = call_counter - 1; + + try { + result = recording->record(func, this, input, *in_variables); + } catch (nb::python_error &e) { + in_variables->release(); + jit_freeze_abort(in_variables->backend); + nb::raise_from( + e, PyExc_RuntimeError, + "record(): error encountered while recording a frozen " + "function (see above)."); + } catch (const std::exception &e) { + in_variables->release(); + jit_freeze_abort(in_variables->backend); + + nb::chain_error(PyExc_RuntimeError, "record(): %s", e.what()); + nb::raise_python_error(); + }; + + in_variables->release(); + + this->prev_key = in_variables; + this->recordings.insert( + { std::move(in_variables), std::move(recording) }); + + } else { + FunctionRecording *recording = it.value().get(); + + recording->last_used = call_counter - 1; + + { + result = recording->replay(func, this, input, *in_variables); + } + + // Drop references to variables + in_variables->release(); + } + } + ad_traverse(drjit::ADMode::Backward, + (uint32_t) drjit::ADFlag::ClearVertices); + return result; +} + +void FrozenFunction::clear() { + recordings.clear(); + prev_key = std::make_shared(FlatVariables()); + recording_counter = 0; + call_counter = 0; +} + +/** + * This function inspects the content of the frozen function to detect reference + * cycles, that could lead to memory or type leaks. It can be called by the + * garbage collector by adding it to the ``type_slots`` of the + * ``FrozenFunction`` definition. + */ +int frozen_function_tp_traverse(PyObject *self, visitproc visit, void *arg) { + FrozenFunction *f = nb::inst_ptr(self); + + nb::handle func = nb::find(f->func); + Py_VISIT(func.ptr()); + + for (auto &it : f->recordings) { + for (auto &layout : it.first->layout) { + nb::handle type = nb::find(layout.type); + Py_VISIT(type.ptr()); + nb::handle object = nb::find(layout.py_object); + Py_VISIT(object.ptr()); + } + for (auto &layout : it.second->out_variables.layout) { + nb::handle type = nb::find(layout.type); + Py_VISIT(type.ptr()); + nb::handle object = nb::find(layout.py_object); + Py_VISIT(object.ptr()); + } + } + + return 0; +} + +/** + * This function releases the internal function of the ``FrozenFunction`` + * object. It is used by the garbage collector to "break" potential reference + * cycles, resulting from the frozen function being referenced in the closure of + * the wrapped variable. + */ +int frozen_function_clear(PyObject *self) { + FrozenFunction *f = nb::inst_ptr(self); + + f->func.release(); + + return 0; +} + +// Slot data structure referencing the above two functions +static PyType_Slot slots[] = { { Py_tp_traverse, + (void *) frozen_function_tp_traverse }, + { Py_tp_clear, (void *) frozen_function_clear }, + { 0, nullptr } }; + +void export_freeze(nb::module_ & /*m*/) { + + nb::module_ d = nb::module_::import_("drjit.detail"); + auto traversable_base = + nb::class_(d, "TraversableBase"); + nb::class_(d, "FrozenFunction", nb::type_slots(slots)) + .def(nb::init()) + .def_prop_ro( + "n_cached_recordings", + [](FrozenFunction &self) { return self.n_cached_recordings(); }) + .def_ro("n_recordings", &FrozenFunction::recording_counter) + .def("clear", &FrozenFunction::clear) + .def("__call__", &FrozenFunction::operator()); +} diff --git a/src/python/freeze.h b/src/python/freeze.h new file mode 100644 index 000000000..b02cf6438 --- /dev/null +++ b/src/python/freeze.h @@ -0,0 +1,529 @@ +/* + freeze.h -- Bindings for drjit.freeze() + + Dr.Jit: A Just-In-Time-Compiler for Differentiable Rendering + Copyright 2023, Realistic Graphics Lab, EPFL. + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE.txt file. +*/ + +#pragma once + +#include "common.h" +#include +#include +#include +#include +#include +#include + +struct FrozenFunction; + +namespace detail { + +using index64_vector = drjit::detail::index64_vector; + +enum class LayoutFlag : uint32_t { + /// Whether this variable has size 1 + SingletonArray = (1 << 0), + /// Whether this variable is unaligned in memory + Unaligned = (1 << 1), + /// Whether this layout represents a literal variable + Literal = (1 << 2), + /// Whether this variable has gradients enabled + GradEnabled = (1 << 3), + /// Did this variable have gradient edges attached when recording, that + /// where postponed by the ``isolate_grad`` function? + Postponed = (1 << 4), +}; + +/// Stores information about python objects, such as their type, their number of +/// sub-elements or their field keys. This can be used to reconstruct a PyTree +/// from a flattened variable array. +struct Layout { + /// Number of members in this container. + /// Can be used to traverse the layout without knowing the type. + uint32_t num = 0; + /// Optional field identifiers of the container + /// for example: keys in dictionary + std::vector fields; + /// The index in the flat_variables array of this variable. + /// This can be used to determine aliasing. + uint32_t index = 0; + + /// Flags, storing information about variables and literals. + uint32_t flags = 0; + + /// The literal data + uint64_t literal = 0; + /// Optional drjit type of the variable + VarType vt = VarType::Void; + + /// If a non drjit type is passed as function arguments or result, we simply + /// cache it here. + /// TODO: possibly do the same for literals? + nb::object py_object; + + /// Nanobind type of the container/variable + nb::type_object type; + + bool operator==(const Layout &rhs) const; + + Layout() = default; + + Layout(const Layout &) = delete; + Layout &operator=(const Layout &) = delete; + + Layout(Layout &&) = default; + Layout &operator=(Layout &&) = default; +}; + +/** + * \brief Stores information about opaque variables. + * + * When traversing a PyTree, literal variables are stored directly and + * non-literal variables are first scheduled and their indices deduplicated and + * added to the ``FlatVariables::variables`` field. After calling ``jit_eval``, + * information about variables can be recorded using + * ``FlatVariables::record_jit_variables``. This struct stores that information + * per deduplicated variable. + */ +struct VarLayout{ + /// Optional drjit type of the variable + VarType vt = VarType::Void; + /// Optional evaluation state of the variable + VarState vs = VarState::Invalid; + /// Flags, storing information about variables + uint32_t flags = 0; + /// We have to track the condition, where two variables have the same size + /// during recording but don't when replaying. + /// Therefore we de-duplicate the size. + uint32_t size_index = 0; + + VarLayout() = default; + + VarLayout(const VarLayout &) = delete; + VarLayout &operator=(const VarLayout &) = delete; + + VarLayout(VarLayout &&) = default; + VarLayout &operator=(VarLayout &&) = default; + + bool operator==(const VarLayout &rhs) const; +}; + + +// Additional context required when traversing the inputs +struct TraverseContext { + /// Set of postponed ad nodes, used to mark inputs to functions. + const tsl::robin_set *postponed = nullptr; + bool schedule_force = false; +}; + +/** + * \brief A flattened representation of the PyTree. + * + * This struct stores a flattened representation of a PyTree as well a + * representation of it. It can therefore be used to either construct the PyTree + * as well as assign the variables to an existing PyTree. Furthermore, this + * struct can also be used as a key to the ``RecordingMap``, determining which + * recording should be used given an input to a frozen function. + * Information about the PyTree is stored in DFS Encoding. Every node of the + * tree is represented by a ``Layout`` element in the ``layout`` vector. + */ +struct FlatVariables { + + uint32_t flags = 0; + + // Index, used to iterate over the variables/layouts when constructing + // python objects + uint32_t layout_index = 0; + + /// The flattened and de-duplicated variable indices of the input/output to + /// a frozen function + std::vector variables; + /// Mapping from drjit jit index to index in flat variables. Used to + /// deduplicate jit indices. + tsl::robin_map index_to_slot; + + /// We have to track the condition, where two variables have the same size + /// during recording but don't when replaying. + /// Therefore we construct equivalence classes of sizes. + /// This vector represents the different sizes, encountered during + /// traversal. The algorithm used to "add" a size is the same as for adding + /// a variable index. + std::vector sizes; + /// Mapping from the size to its index in the ``sizes`` vector. This is used + /// to construct size equivalence classes (i.e. deduplicating sizes). + tsl::robin_map size_to_slot; + + /// This saves information about the type, size and fields of pytree + /// objects. The information is stored in DFS order. + std::vector layout; + /// Stores information about non-literal jit variables. + std::vector var_layout; + /// The collective backend for all input variables. It can be used to ensure + /// that all variables have the same backend. + JitBackend backend = JitBackend::None; + /// The variant, if any, used to traverse the registry. + std::string variant; + /// All domains (deduplicated), encountered while traversing the PyTree and + /// its C++ objects. This can be used to traverse the registry. We use a + /// vector instead of a hash set, since we expect the number of domains not + /// to exceed 100. + std::vector domains; + + uint32_t recursion_level = 0; + + struct recursion_guard { + FlatVariables *flat_variables; + recursion_guard(FlatVariables *flat_variables) + : flat_variables(flat_variables) { + if (++flat_variables->recursion_level >= 50) { + PyErr_SetString(PyExc_RecursionError, + "runaway recursion detected"); + nb::raise_python_error(); + } + } + ~recursion_guard() { flat_variables->recursion_level--; } +}; + + + /** + * Describes how many elements have to be pre-allocated for the ``layout``, + * ``index_to_slot`` and ``size_to_slot`` containers. + */ + struct Heuristic { + size_t layout = 0; + size_t index_to_slot = 0; + size_t size_to_slot = 0; + + Heuristic max(Heuristic rhs){ + return Heuristic{ + std::max(layout, rhs.layout), + std::max(index_to_slot, rhs.index_to_slot), + std::max(size_to_slot, rhs.size_to_slot), + }; + } + }; + + FlatVariables() {} + FlatVariables(Heuristic heuristic) { + layout.reserve(heuristic.layout); + index_to_slot.reserve(heuristic.index_to_slot); + size_to_slot.reserve(heuristic.size_to_slot); + } + + FlatVariables(const FlatVariables &) = delete; + FlatVariables &operator=(const FlatVariables &) = delete; + + FlatVariables(FlatVariables &&) = default; + FlatVariables &operator=(FlatVariables &&) = default; + + void clear() { + this->layout_index = 0; + this->variables.clear(); + this->index_to_slot.clear(); + this->layout.clear(); + this->backend = JitBackend::None; + } + /// Borrow all variables held by this struct. + void borrow() { + for (uint32_t &index : this->variables) + jit_var_inc_ref(index); + } + /// Release all variables held by this struct. + void release() { + for (uint32_t &index : this->variables) + jit_var_dec_ref(index); + } + + /** + * \brief Records information about jit variables, that have been traversed. + * + * After traversing the PyTree, collecting non-literal indices in + * ``variables`` and evaluating the collected indices, we can collect + * information about the underlying variables. This information is used in + * the key of the ``RecordingMap`` to determine which recording should be + * replayed or if the function has to be re-traced. This function iterates + * over the collected indices and collects that information. + */ + void record_jit_variables(); + + /** + * Returns a struct representing heuristics to pre-allocate memory for the + * layout, of the flat variables. + */ + Heuristic heuristic() { + return Heuristic{ + layout.size(), + index_to_slot.size(), + size_to_slot.size(), + }; + }; + + /** + * \brief Add a variant domain pair to be traversed using the registry. + * + * When traversing a jit variable, that references a pointer to a class, + * such as a BSDF or Shape in Mitsuba, we have to traverse all objects + * registered with that variant-domain pair in the registry. This function + * adds the variant-domain pair, deduplicating the domain. Whether a + * variable references a class is represented by it's ``IsClass`` const + * attribute. If the domain is an empty string (""), this function skips + * adding the variant-domain pair. + */ + void add_domain(const char *variant, const char *domain); + + /** + * Adds a jit index to the flattened array, deduplicating it. + * This allows to check for aliasing conditions, where two variables + * actually refer to the same index. The function should only be called for + * scheduled non-literal variable indices. + */ + uint32_t add_jit_index(uint32_t variable_index); + + /** + * This function returns an index into the ``sizes`` vector, representing an + * equivalence class of variable sizes. It uses a HashMap and vector to + * deduplicate sizes. + * + * This is necessary, to catch cases, where two variables had the same size + * when recording a function and two different sizes when replaying. + * In that case one kernel would be recorded, that evaluates both variables. + * However, when replaying two kernels would have to be launched since the + * now differently sized variables cannot be evaluated by the same kernel. + */ + uint32_t add_size(uint32_t size); + + /** + * Traverse the variable referenced by a jit index and add it to the flat + * variables. An optional type python type can be supplied if it is known. + * Depending on the ``TraverseContext::schedule_force`` the underlying + * variable is either scheduled (``jit_var_schedule``) or force scheduled + * (``jit_var_schedule_force``). If the variable after evaluation is a + * literal, it is directly recorded in the ``layout`` otherwise, it is added + * to the ``variables`` array, allowing the variables to be used when + * recording the frozen function. + */ + void traverse_jit_index(uint32_t index, TraverseContext &ctx, + nb::handle tp = nullptr); + /** + * Add an ad variable by it's index. Both the value and gradient are added + * to the flattened variables. If the ad index has been marked as postponed + * in the \c TraverseContext.postponed field, we mark the resulting layout + * with that flag. This will cause the gradient edges to be propagated when + * assigning to the input. The function takes an optional python-type if + * it is known. + */ + void traverse_ad_index(uint64_t index, TraverseContext &ctx, + nb::handle tp = nullptr); + + /** + * Wrapper aground traverse_ad_index for a python handle. + */ + void traverse_ad_var(nb::handle h, TraverseContext &ctx); + + /** + * Traverse a c++ tree using it's `traverse_1_cb_ro` callback. + */ + void traverse_cb(const drjit::TraversableBase *traversable, + TraverseContext &ctx, nb::object type = nb::none()); + + /** + * Traverses a PyTree in DFS order, and records it's layout in the + * `layout` vector. + * + * When hitting a drjit primitive type, it calls the + * `traverse_dr_var` method, which will add their indices to the + * `flat_variables` vector. The collect method will also record metadata + * about the drjit variable in the layout. Therefore, the layout can be used + * as an identifier to the recording of the frozen function. + */ + void traverse(nb::handle h, TraverseContext &ctx); + + /** + * First traverses the PyTree, then the registry. This ensures that + * additional data to vcalls is tracked correctly. + */ + void traverse_with_registry(nb::handle h, TraverseContext &ctx); + + /** + * Construct a variable, given it's layout. + * This is the counterpart to `traverse_jit_index`. + * + * Optionally, the index of a variable can be provided that will be + * overwritten with the result of this function. In that case, the function + * will check for compatible variable types. + */ + uint32_t construct_jit_index(uint32_t prev_index = 0); + + /** + * Construct/assign the variable index given a layout. + * This corresponds to `traverse_ad_index`> + * + * This function is also used for assignment to ad-variables. + * If a `prev_index` is provided, and it is an ad-variable the gradient and + * value of the flat variables will be applied to the ad variable, + * preserving the ad_idnex. + * + * It returns an owning reference. + */ + uint64_t construct_ad_index(uint32_t shrink = 0, + uint64_t prev_index = 0); + + /** + * Construct an ad variable given it's layout. + * This corresponds to `traverse_ad_var` + */ + nb::object construct_ad_var(const Layout &layout, uint32_t shrink = 0); + + /** + * This is the counterpart to the traverse method, used to construct the + * output of a frozen function. Given a layout vector and flat_variables, it + * re-constructs the PyTree. + */ + nb::object construct(); + + /** + * Assigns an ad variable. + * Corresponds to `traverse_ad_var`. + * This uses `construct_ad_index` to either construct a new ad variable or + * assign the value and gradient to an already existing one. + */ + void assign_ad_var(Layout &layout, nb::handle dst); + + /** + * Helper function, used to assign a callback variable. + * + * \param tmp + * This vector is populated with the indices to variables that have been + * constructed. It is required to release the references, since the + * references created by `construct_ad_index` are owning and they are + * borrowed after the callback returns. + */ + uint64_t assign_cb_internal(uint64_t index, index64_vector &tmp); + + /** + * Assigns variables using it's `traverse_cb_rw` callback. + * This corresponds to `traverse_cb`. + */ + void assign_cb(drjit::TraversableBase *traversable); + + /** + * Assigns the flattened variables to an already existing PyTree. + * This is used when input variables have changed. + */ + void assign(nb::handle dst); + + /** + * First assigns the registry and then the PyTree. + * Corresponds to `traverse_with_registry`. + */ + void assign_with_registry(nb::handle dst); + + bool operator==(const FlatVariables &rhs) const { + return this->layout == rhs.layout && + this->var_layout == rhs.var_layout && this->flags == rhs.flags; + } +}; + +/// Helper struct to hash input variables +struct FlatVariablesHasher { + size_t operator()(const std::shared_ptr &key) const; +}; + +/// Helper struct to compare input variables +struct FlatVariablesEqual{ + using is_transparent = void; + bool operator()(const std::shared_ptr &lhs, + const std::shared_ptr &rhs) const { + return *lhs.get() == *rhs.get(); + } +}; + +/** + * \brief A recording of a frozen function, recorded with a certain layout of + * input variables. + */ +struct FunctionRecording { + uint32_t last_used = 0; + Recording *recording = nullptr; + FlatVariables out_variables; + + FunctionRecording() : out_variables() {} + FunctionRecording(const FunctionRecording &) = delete; + FunctionRecording &operator=(const FunctionRecording &) = delete; + FunctionRecording(FunctionRecording &&) = default; + FunctionRecording &operator=(FunctionRecording &&) = default; + + ~FunctionRecording() { + if (this->recording) { + jit_freeze_destroy(this->recording); + } + this->recording = nullptr; + } + + void clear() { + if (this->recording) { + jit_freeze_destroy(this->recording); + } + this->recording = nullptr; + this->out_variables = FlatVariables(); + } + + /* + * Record a function, given it's python input and flattened input. + */ + nb::object record(nb::callable func, FrozenFunction *frozen_func, + nb::list input, const FlatVariables &in_variables); + /* + * Replays the recording. + * + * This constructs the output and re-assigns the input. + */ + nb::object replay(nb::callable func, FrozenFunction *frozen_func, + nb::list input, const FlatVariables &in_variables); +}; + +using RecordingMap = tsl::robin_map, + std::unique_ptr, + FlatVariablesHasher, FlatVariablesEqual>; + +} // namespace detail + +struct FrozenFunction { + nb::callable func; + + detail::RecordingMap recordings; + std::shared_ptr prev_key; + + uint32_t recording_counter = 0; + uint32_t call_counter = 0; + int max_cache_size = -1; + uint32_t warn_recording_count = 10; + JitBackend default_backend = JitBackend::None; + + detail::FlatVariables::Heuristic in_heuristics; + + FrozenFunction(nb::callable func, int max_cache_size = -1, + uint32_t warn_recording_count = 10, + JitBackend backend = JitBackend::None) + : func(func), max_cache_size(max_cache_size), + warn_recording_count(warn_recording_count), default_backend(backend) { + } + ~FrozenFunction() {} + + FrozenFunction(const FrozenFunction &) = delete; + FrozenFunction &operator=(const FrozenFunction &) = delete; + FrozenFunction(FrozenFunction &&) = default; + FrozenFunction &operator=(FrozenFunction &&) = default; + + uint32_t n_cached_recordings() { return this->recordings.size(); } + + void clear(); + + nb::object operator()(nb::args args, nb::kwargs kwargs); +}; + +extern void export_freeze(nb::module_ &); diff --git a/src/python/main.cpp b/src/python/main.cpp index 99712f3ec..27cb02dff 100644 --- a/src/python/main.cpp +++ b/src/python/main.cpp @@ -8,7 +8,7 @@ BSD-style license that can be found in the LICENSE.txt file. */ -#define NB_INTRUSIVE_EXPORT NB_EXPORT +#define NB_INTRUSIVE_EXPORT NB_IMPORT #include #include @@ -22,6 +22,7 @@ #include "cuda.h" #include "reduce.h" #include "eval.h" +#include "freeze.h" #include "iter.h" #include "init.h" #include "memop.h" @@ -106,6 +107,9 @@ NB_MODULE(_drjit_ext, m_) { .value("ScatterReduceLocal", JitFlag::ScatterReduceLocal, doc_JitFlag_ScatterReduceLocal) .value("SymbolicConditionals", JitFlag::SymbolicConditionals, doc_JitFlag_SymbolicConditionals) .value("SymbolicScope", JitFlag::SymbolicScope, doc_JitFlag_SymbolicScope) + .value("KernelFreezing", JitFlag::KernelFreezing, doc_JitFlag_KernelFreezing) + .value("FreezingScope", JitFlag::FreezingScope, doc_JitFlag_FreezingScope) + .value("EnableObjectTraversal", JitFlag::EnableObjectTraversal, doc_JitFlag_EnableObjectTraversal) .value("Default", JitFlag::Default, doc_JitFlag_Default) // Deprecated aliases @@ -235,6 +239,7 @@ NB_MODULE(_drjit_ext, m_) { export_iter(detail); export_reduce(m); export_eval(m); + export_freeze(m); export_memop(m); export_slice(m); export_dlpack(m); diff --git a/src/python/texture.h b/src/python/texture.h index e8d009856..4068ec226 100644 --- a/src/python/texture.h +++ b/src/python/texture.h @@ -173,6 +173,8 @@ void bind_texture(nb::module_ &m, const char *name) { #undef def_tex_eval_cubic_helper tex.attr("IsTexture") = true; + + drjit::bind_traverse(tex); } template diff --git a/src/python/tracker.cpp b/src/python/tracker.cpp index ca68abf5a..6c048ce3d 100644 --- a/src/python/tracker.cpp +++ b/src/python/tracker.cpp @@ -191,8 +191,8 @@ struct VariableTracker::Context { check_size(check_size), index_offset(0) { } // Internal API for type-erased traversal - uint64_t _traverse_write(uint64_t idx); - void _traverse_read(uint64_t index); + uint64_t _traverse_write(uint64_t idx, const char *, const char *); + void _traverse_read(uint64_t index, const char *, const char *); }; // Temporarily push a value onto the stack @@ -574,7 +574,8 @@ bool VariableTracker::Impl::traverse(Context &ctx, nb::handle h) { return changed; } -uint64_t VariableTracker::Context::_traverse_write(uint64_t idx) { +uint64_t VariableTracker::Context::_traverse_write(uint64_t idx, const char *, + const char *) { if (!idx) return 0; if (index_offset >= indices.size()) @@ -611,7 +612,7 @@ uint64_t VariableTracker::Context::_traverse_write(uint64_t idx) { return idx_new; } -void VariableTracker::Context::_traverse_read(uint64_t index) { +void VariableTracker::Context::_traverse_read(uint64_t index, const char *, const char *) { if (!index) return; indices.push_back(ad_var_inc_ref(index)); diff --git a/tests/call_ext.cpp b/tests/call_ext.cpp index 90dd869c2..8b5d85de7 100644 --- a/tests/call_ext.cpp +++ b/tests/call_ext.cpp @@ -6,6 +6,7 @@ #include #include #include +#include namespace nb = nanobind; namespace dr = drjit; @@ -13,29 +14,25 @@ namespace dr = drjit; using namespace nb::literals; template -struct Sampler { +struct Sampler : dr::TraversableBase { + Sampler() : rng(1) {} Sampler(size_t size) : rng(size) { } T next() { return rng.next_float32(); } - void traverse_1_cb_ro(void *payload, void (*fn)(void *, uint64_t)) const { - traverse_1_fn_ro(rng, payload, fn); - } - - void traverse_1_cb_rw(void *payload, uint64_t (*fn)(void *, uint64_t)) { - traverse_1_fn_rw(rng, payload, fn); - } - dr::PCG32> rng; + + DR_TRAVERSE_CB(dr::TraversableBase, rng); }; -template struct Base : nb::intrusive_base { +template struct Base : drjit::TraversableBase { using Mask = dr::mask_t; using UInt32 = dr::uint32_array_t; virtual std::pair f(Float x, Float y) = 0; virtual std::pair f_masked(const std::pair &xy, Mask active) = 0; virtual Float g(Float, Mask) = 0; + virtual Float h(Float) = 0; virtual Float nested(Float x, UInt32 s) = 0; virtual void dummy() = 0; virtual float scalar_getter() = 0; @@ -50,10 +47,12 @@ template struct Base : nb::intrusive_base { Base() { if constexpr (dr::is_jit_v) - jit_registry_put("", "Base", this); + drjit::registry_put("", "Base", this); } virtual ~Base() { jit_registry_remove(this); } + + DR_TRAVERSE_CB(drjit::TraversableBase) }; template struct A : Base { @@ -74,6 +73,10 @@ template struct A : Base { return value; } + virtual Float h(Float x) override{ + return value + x; + } + virtual Float nested(Float x, UInt32 /*s*/) override { return x + dr::gather(value, UInt32(0)); } @@ -112,6 +115,8 @@ template struct A : Base { uint32_t scalar_property; Float value, extra_value; Float opaque = dr::opaque(1.f); + + DR_TRAVERSE_CB(Base, value, opaque) }; template struct B : Base { @@ -132,6 +137,10 @@ template struct B : Base { return value*x; } + virtual Float h(Float x) override{ + return value - x; + } + virtual Float nested(Float x, UInt32 s) override { using BaseArray = dr::replace_value_t*>; BaseArray self = dr::reinterpret_array(s); @@ -160,6 +169,8 @@ template struct B : Base { Float value; Float opaque = dr::opaque(2.f); + + DR_TRAVERSE_CB(Base, value, opaque) }; DRJIT_CALL_TEMPLATE_BEGIN(Base) @@ -167,6 +178,7 @@ DRJIT_CALL_TEMPLATE_BEGIN(Base) DRJIT_CALL_METHOD(f_masked) DRJIT_CALL_METHOD(dummy) DRJIT_CALL_METHOD(g) + DRJIT_CALL_METHOD(h) DRJIT_CALL_METHOD(nested) DRJIT_CALL_METHOD(sample) DRJIT_CALL_METHOD(gather_packet) @@ -198,20 +210,22 @@ void bind(nb::module_ &m) { using Sampler = ::Sampler; auto sampler = nb::class_(m, "Sampler") + .def(nb::init<>()) .def(nb::init()) .def("next", &Sampler::next) .def_rw("rng", &Sampler::rng); bind_traverse(sampler); - nb::class_(m, "Base") + auto base_cls = nb::class_(m, "Base") .def("f", &BaseT::f) .def("f_masked", &BaseT::f_masked) .def("g", &BaseT::g) .def("nested", &BaseT::nested) .def("sample", &BaseT::sample); + bind_traverse(base_cls); - nb::class_(m, "A") + auto a_cls = nb::class_(m, "A") .def(nb::init<>()) .def("a_get_property", &AT::a_get_property) .def("a_gather_extra_value", &AT::a_gather_extra_value) @@ -219,11 +233,13 @@ void bind(nb::module_ &m) { .def_rw("value", &AT::value) .def_rw("extra_value", &AT::extra_value) .def_rw("scalar_property", &AT::scalar_property); + bind_traverse(a_cls); - nb::class_(m, "B") + auto b_cls = nb::class_(m, "B") .def(nb::init<>()) .def_rw("opaque", &BT::opaque) .def_rw("value", &BT::value); + bind_traverse(b_cls); using BaseArray = dr::DiffArray; m.def("dispatch_f", [](BaseArray &self, Float a, Float b) { @@ -243,6 +259,7 @@ void bind(nb::module_ &m) { .def("g", [](BaseArray &self, Float x, Mask m) { return self->g(x, m); }, "x"_a, "mask"_a = true) + .def("h", [](BaseArray &self, Float x) { return self->h(x); }, "x"_a) .def("nested", [](BaseArray &self, Float x, UInt32 s) { return self->nested(x, s); }, "x"_a, "s"_a) diff --git a/tests/custom_type_ext.cpp b/tests/custom_type_ext.cpp index 50c936ee0..7bff08779 100644 --- a/tests/custom_type_ext.cpp +++ b/tests/custom_type_ext.cpp @@ -1,6 +1,10 @@ -#include #include #include +#include +#include +#include +#include +#include namespace nb = nanobind; namespace dr = drjit; @@ -42,6 +46,65 @@ struct CustomHolder { Value m_value; }; +class Object : public drjit::TraversableBase { + DR_TRAVERSE_CB(drjit::TraversableBase); +}; + +template +class CustomBase : public Object{ + Value m_base_value; + +public: + CustomBase(const Value &base_value) : Object(), m_base_value(base_value) {} + + Value &base_value() { return m_base_value; } + virtual Value &value() = 0; + + DR_TRAVERSE_CB(Object, m_base_value); +}; + +template +class PyCustomBase : public CustomBase{ +public: + using Base = CustomBase; + NB_TRAMPOLINE(Base, 1); + + PyCustomBase(const Value &base_value) : Base(base_value) {} + + Value &value() override { NB_OVERRIDE_PURE(value); } + + DR_TRAMPOLINE_TRAVERSE_CB(Base); +}; + +template +class CustomA: public CustomBase{ +public: + using Base = CustomBase; + + CustomA(const Value &value, const Value &base_value) : Base(base_value), m_value(value) {} + + Value &value() override { return m_value; } + +private: + Value m_value; + + DR_TRAVERSE_CB(Base, m_value); +}; + +template +class Nested: Object{ + using Base = Object; + + std::vector, size_t>> m_nested; + +public: + Nested(nb::ref a, nb::ref b) { + m_nested.push_back(std::make_pair(a, 0)); + m_nested.push_back(std::make_pair(b, 1)); + } + + DR_TRAVERSE_CB(Base, m_nested); +}; template void bind(nb::module_ &m) { dr::ArrayBinding b; @@ -64,12 +127,50 @@ template void bind(nb::module_ &m) { .def(nb::init()) .def("value", &CustomFloatHolder::value, nanobind::rv_policy::reference); + using CustomBase = CustomBase; + using PyCustomBase = PyCustomBase; + using CustomA = CustomA; + using Nested = Nested; + + auto object = nb::class_( + m, "Object", + nb::intrusive_ptr( + [](Object *o, PyObject *po) noexcept { o->set_self_py(po); })); + + auto base = + nb::class_(m, "CustomBase") + .def(nb::init()) + .def("value", nb::overload_cast<>(&CustomBase::value)) + .def("base_value", nb::overload_cast<>(&CustomBase::base_value)); + + drjit::bind_traverse(base); + + auto a = nb::class_(m, "CustomA") + .def(nb::init()); + + drjit::bind_traverse(a); + + auto nested = nb::class_(m, "Nested") + .def(nb::init, nb::ref>()); + + drjit::bind_traverse(nested); + m.def("cpp_make_opaque", [](CustomFloatHolder &holder) { dr::make_opaque(holder); } ); } NB_MODULE(custom_type_ext, m) { + nb::intrusive_init( + [](PyObject *o) noexcept { + nb::gil_scoped_acquire guard; + Py_INCREF(o); + }, + [](PyObject *o) noexcept { + nb::gil_scoped_acquire guard; + Py_DECREF(o); + }); + #if defined(DRJIT_ENABLE_LLVM) nb::module_ llvm = m.def_submodule("llvm"); bind(llvm); diff --git a/tests/test_custom_type_ext.py b/tests/test_custom_type_ext.py index 90c9b7a23..ad97bb3ec 100644 --- a/tests/test_custom_type_ext.py +++ b/tests/test_custom_type_ext.py @@ -1,7 +1,6 @@ import drjit as dr import pytest - def get_pkg(t): with dr.detail.scoped_rtld_deepbind(): m = pytest.importorskip("custom_type_ext") @@ -69,3 +68,96 @@ def test03_cpp_make_opaque(t): pkg.cpp_make_opaque(holder) assert holder.value().state == dr.VarState.Evaluated + + +@pytest.test_arrays("float32,-diff,shape=(*),jit") +def test04_traverse_opaque(t): + """ + Tests that it is possible to traverse an opaque C++ object. + """ + pkg = get_pkg(t) + Float = t + + value = dr.arange(Float, 10) + base_value = dr.arange(Float, 10) + + a = pkg.CustomA(value, base_value) + assert dr.detail.collect_indices(a) == [base_value.index, value.index] + + +@pytest.test_arrays("float32,-diff,shape=(*),jit") +def test05_traverse_py(t): + """ + Tests the implementation of ``traverse_py_cb_ro``, which is used to traverse + python objects in trampoline classes. + """ + Float = t + + v = dr.arange(Float, 10) + + class PyClass: + def __init__(self, v) -> None: + self.v = v + + c = PyClass(v) + + result = [] + + def callback(index, domain, variant): + result.append(index) + + dr.detail.traverse_py_cb_ro(c, callback) + + assert result == [v.index] + + +@pytest.test_arrays("float32,-diff,shape=(*),jit") +def test06_trampoline_traversal(t): + """ + Tests that classes inheriting from trampoline classes are traversed + automatically. + """ + pkg = get_pkg(t) + Float = t + + value = dr.opaque(Float, 0, 3) + base_value = dr.opaque(Float, 1, 3) + + class B(pkg.CustomBase): + def __init__(self, value, base_value) -> None: + super().__init__(base_value) + self._value = value + + def value(self): + return self._value + + b = B(value, base_value) + + assert dr.detail.collect_indices(b) == [base_value.index, value.index] + +@pytest.test_arrays("float32,-diff,shape=(*),jit") +def test07_nested_traversal(t): + """ + Test traversal of nested objects, and more specifically the traversal of + ``std::vector, size_t>>`` members. + """ + pkg = get_pkg(t) + Float = t + + value = dr.arange(Float, 10) + 0 + base_value = dr.arange(Float, 10) + 1 + + a = pkg.CustomA(value, base_value) + + value = dr.arange(Float, 10) + 2 + base_value = dr.arange(Float, 10) + 3 + + b = pkg.CustomA(value, base_value) + + nested = pkg.Nested(a, b) + + indices_a = dr.detail.collect_indices(a) + indices_b = dr.detail.collect_indices(b) + indices_nested = dr.detail.collect_indices(nested) + + assert indices_nested == indices_a + indices_b diff --git a/tests/test_freeze.py b/tests/test_freeze.py new file mode 100644 index 000000000..fc34a3d40 --- /dev/null +++ b/tests/test_freeze.py @@ -0,0 +1,2824 @@ +import drjit as dr +import pytest +from math import ceil +from dataclasses import dataclass +import sys + +def get_single_entry(x): + tp = type(x) + result = x + shape = dr.shape(x) + if len(shape) == 2: + result = result[shape[0] - 1] + if len(shape) == 3: + result = result[shape[0] - 1][shape[1] - 1] + return result + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test01_basic(t): + """ + Tests a very basic frozen function, adding two integers x, y. + """ + + @dr.freeze + def func(x, y): + return x + y + + i0 = t(0, 1, 2) + i1 = t(2, 1, 0) + + o0 = func(i0, i1) + assert dr.all(t(2, 2, 2) == o0) + + i0 = t(1, 2, 3) + i1 = t(3, 2, 1) + + o0 = func(i0, i1) + assert dr.all(t(4, 4, 4) == o0) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test02_flush_kernel_cache(t): + """ + Tests that flushing the kernel between recording and replaying a frozen + function causes the function to be re-traced. + """ + + def func(x, y): + return x + y + + frozen = dr.freeze(func) + + x = t(0, 1, 2) + y = t(2, 1, 0) + + res = frozen(x, y) + ref = func(x, y) + assert dr.all(res == ref) + + dr.flush_kernel_cache() + + x = t(1, 2, 3) + y = t(3, 2, 1) + + # Flushing the kernel cache should force a re-trace + res = frozen(x, y) + ref = func(x, y) + assert dr.all(res == ref) + assert frozen.n_recordings == 2 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test03_output_tuple(t): + """ + Tests that returning tuples from frozen functions is possible. + """ + + @dr.freeze + def func(x, y): + return (x + y, x * y) + + i0 = t(0, 1, 2) + i1 = t(2, 1, 0) + + (o0, o1) = func(i0, i1) + assert dr.all(t(2, 2, 2) == o0) + assert dr.all(t(0, 1, 0) == o1) + + i0 = t(1, 2, 3) + i1 = t(3, 2, 1) + + (o0, o1) = func(i0, i1) + assert dr.all(t(4, 4, 4) == o0) + assert dr.all(t(3, 4, 3) == o1) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test04_output_list(t): + """ + Tests that returning lists from forzen functions is possible. + """ + + @dr.freeze + def func(x, y): + return [x + y, x * y] + + i0 = t(0, 1, 2) + i1 = t(2, 1, 0) + + [o0, o1] = func(i0, i1) + assert dr.all(t(2, 2, 2) == o0) + assert dr.all(t(0, 1, 0) == o1) + + i0 = t(1, 2, 3) + i1 = t(3, 2, 1) + + [o0, o1] = func(i0, i1) + assert dr.all(t(4, 4, 4) == o0) + assert dr.all(t(3, 4, 3) == o1) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test05_output_dict(t): + """ + Tests that returning dictionaries from forzen functions is possible. + """ + + @dr.freeze + def func(x, y): + return {"add": x + y, "mul": x * y} + + i0 = t(0, 1, 2) + i1 = t(2, 1, 0) + + o = func(i0, i1) + o0 = o["add"] + o1 = o["mul"] + assert dr.all(t(2, 2, 2) == o0) + assert dr.all(t(0, 1, 0) == o1) + + i0 = t(1, 2, 3) + i1 = t(3, 2, 1) + + o = func(i0, i1) + o0 = o["add"] + o1 = o["mul"] + assert dr.all(t(4, 4, 4) == o0) + assert dr.all(t(3, 4, 3) == o1) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test06_nested_tuple(t): + """ + Tests that returning nested tuples from forzen functions is possible. + """ + + @dr.freeze + def func(x): + return (x + 1, x + 2, (x + 3, x + 4)) + + i0 = t(0, 1, 2) + + (o0, o1, (o2, o3)) = func(i0) + assert dr.all(t(1, 2, 3) == o0) + assert dr.all(t(2, 3, 4) == o1) + assert dr.all(t(3, 4, 5) == o2) + assert dr.all(t(4, 5, 6) == o3) + + i0 = t(1, 2, 3) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test07_drjit_struct(t): + """ + Tests that returning custom classes, annotated with ``DRJIT_STRUCT`` from + forzen functions is possible. + """ + + class Point: + x: t + y: t + DRJIT_STRUCT = {"x": t, "y": t} + + @dr.freeze + def func(x): + p = Point() + p.x = x + 1 + p.y = x + 2 + return p + + i0 = t(0, 1, 2) + + o = func(i0) + o0 = o.x + o1 = o.y + assert dr.all(t(1, 2, 3) == o0) + assert dr.all(t(2, 3, 4) == o1) + + i0 = t(1, 2, 3) + + o = func(i0) + o0 = o.x + o1 = o.y + assert dr.all(t(2, 3, 4) == o0) + assert dr.all(t(3, 4, 5) == o1) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test08_dataclass(t): + """ + Tests that returning custom dataclasses from forzen functions is possible. + """ + + @dataclass + class Point: + x: t + y: t + + @dr.freeze + def func(x): + p = Point(x + 1, x + 2) + return p + + i0 = t(0, 1, 2) + + o = func(i0) + o0 = o.x + o1 = o.y + assert dr.all(t(1, 2, 3) == o0) + assert dr.all(t(2, 3, 4) == o1) + + i0 = t(1, 2, 3) + + o = func(i0) + o0 = o.x + o1 = o.y + assert dr.all(t(2, 3, 4) == o0) + assert dr.all(t(3, 4, 5) == o1) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test09_traverse_cb(t): + """ + Tests that passing opaque C++ objects to frozen functions is possible. + It should not be possible to return these from frozen functions. + """ + pkg = get_pkg(t) + Sampler = pkg.Sampler + + def func(sampler): + return sampler.next() + + frozen = dr.freeze(func) + + sampler_frozen = Sampler(10) + sampler_func = Sampler(10) + + result1_frozen = frozen(sampler_frozen) + result1_func = func(sampler_func) + assert dr.allclose(result1_frozen, result1_func) + + sampler_frozen = Sampler(10) + sampler_func = Sampler(10) + + result2_frozen = frozen(sampler_frozen) + result2_func = func(sampler_func) + assert dr.allclose(result2_frozen, result2_func) + + assert frozen.n_recordings == 1 + + result3_frozen = frozen(sampler_frozen) + result3_func = func(sampler_func) + assert dr.allclose(result3_func, result3_frozen) + + assert frozen.n_recordings == 1 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test10_scatter(t): + """ + Tests that it is possible to scatter to the input of a frozen function, + while leaving variables depending on the input the same (scattering problem). + """ + + @dr.freeze + def func(x): + dr.scatter(x, 0, dr.arange(t, 3)) + + x = t(0, 1, 2) + func(x) + + x = t(0, 1, 2) + y = x + 1 + z = x + w = t(x) + + func(x) + + assert dr.all(t(0, 0, 0) == x) + assert dr.all(t(1, 2, 3) == y) + assert dr.all(t(0, 0, 0) == z) + assert dr.all(t(0, 1, 2) == w) + + +@pytest.test_arrays("float32, jit, is_diff, shape=(*)") +def test11_optimization(t): + """ + Implements a simple gradient descent optimization of a variable in a + frozen function. This verifies that gradient descent kernels are evaluated + correctly. + """ + + @dr.freeze + def func(state, ref): + for k, x in state.items(): + dr.enable_grad(x) + loss = dr.mean(dr.square(x - ref)) + + dr.backward(loss) + + grad = dr.grad(x) + dr.disable_grad(x) + state[k] = x - grad + + state = {"x": t(0, 0, 0, 0)} + + ref = t(1, 1, 1, 1) + func(state, ref) + assert dr.allclose(t(0.5, 0.5, 0.5, 0.5), state["x"]) + + state = {"x": t(0, 0, 0, 0)} + ref = t(1, 1, 1, 1) + func(state, ref) + + assert dr.allclose(t(0.5, 0.5, 0.5, 0.5), state["x"]) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test12_resized(t): + """ + Tests that it is possible to call a frozen function with inputs of different + size compared to the recording without having to re-trace the function. + """ + + @dr.freeze + def func(x, y): + return x + y + + i0 = t(0, 1, 2) + i1 = t(2, 1, 0) + + o0 = func(i0, i1) + assert dr.all(t(2, 2, 2) == o0) + + i0 = dr.arange(t, 64) + dr.opaque(t, 0) + i1 = dr.arange(t, 64) + dr.opaque(t, 0) + r0 = i0 + i1 + dr.eval(i0, i1, r0) + + o0 = func(i0, i1) + assert dr.all(r0 == o0) + assert func.n_recordings == 1 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test13_changed_input_dict(t): + """ + Test that it is possible to pass a dictionary to a frozen function, that is + inserting the result at a new key in said dictionary. This ensures that the + input is evaluated correctly, and the dictionary is back-assigned to the input. + """ + + @dr.freeze + def func(d: dict): + d["y"] = d["x"] + 1 + + x = t(0, 1, 2) + d = {"x": x} + + func(d) + assert dr.all(t(1, 2, 3) == d["y"]) + + x = t(1, 2, 3) + d = {"x": x} + + func(d) + assert dr.all(t(2, 3, 4) == d["y"]) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test14_changed_input_dataclass(t): + """ + Tests that it is possible to asing to the input of a dataclass inside a + frozen function. This also relies on correct back-assignment of the input. + """ + + @dataclass + class Point: + x: t + + @dr.freeze + def func(p: Point): + p.x = p.x + 1 + + p = Point(x=t(0, 1, 2)) + + func(p) + assert dr.all(t(1, 2, 3) == p.x) + + p = Point(x=t(1, 2, 3)) + + func(p) + assert dr.all(t(2, 3, 4) == p.x) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test15_kwargs(t): + """ + Tests that it is possible to pass keyword arguments to a frozen function + that modifies them. + """ + + @dr.freeze + def func(x=t(0, 1, 2)): + return x + 1 + + y = func(x=t(0, 1, 2)) + assert dr.all(t(1, 2, 3) == y) + + y = func(x=t(1, 2, 3)) + assert dr.all(t(2, 3, 4) == y) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test16_opaque(t): + """ + Tests that changing from an opaque (1-sized array) to an array of size + larger than 1 causes the funcion to be re-traced. This is necessary, because + different kernels are compiled for the two cases. + """ + + @dr.freeze + def func(x, y): + return x + y + + x = t(0, 1, 2) + dr.set_label(x, "x") + y = dr.opaque(t, 1) + dr.set_label(y, "y") + z = func(x, y) + assert dr.all(t(1, 2, 3) == z) + + x = t(1, 2, 3) + y = t(1, 2, 3) + z = func(x, y) + assert dr.all(t(2, 4, 6) == z) + + assert func.n_recordings == 2 + + +@pytest.test_arrays("float32, jit, -is_diff, shape=(*)") +def test17_performance(t): + """ + Tests the performance of a frozen function versus a non-frozen function. + """ + import time + + n = 1024 + n_iter = 1_000 + n_iter_warmeup = 10 + + def func(x, y): + z = 0.5 + result = dr.fma(dr.square(x), y, z) + result = dr.sqrt(dr.abs(result) + dr.power(result, 10)) + result = dr.log(1 + result) + return result + + frozen = dr.freeze(func) + + for name, fn in [("normal", func), ("frozen", frozen)]: + x = dr.arange(t, n) # + dr.opaque(t, i) + y = dr.arange(t, n) # + dr.opaque(t, i) + dr.eval(x, y) + for i in range(n_iter + n_iter_warmeup): + if i == n_iter_warmeup: + t0 = time.time() + + result = fn(x, y) + + dr.eval(result) + + dr.sync_thread() + elapsed = time.time() - t0 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test18_aliasing(t): + """ + Tests that changing the inputs from being the same variable to two different + variables causes the function to be re-traced. + """ + + @dr.freeze + def func(x, y): + return x + y + + x = t(0, 1, 2) + y = x + z = func(x, y) + assert dr.all(t(0, 2, 4) == z) + + x = t(1, 2, 3) + y = x + z = func(x, y) + assert dr.all(t(2, 4, 6) == z) + + x = t(1, 2, 3) + y = t(2, 3, 4) + z = func(x, y) + assert dr.all(t(3, 5, 7) == z) + assert func.n_recordings == 2 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test19_non_jit_types(t): + """ + Tests that it is possible to pass non-jit types such as integers to frozen + functions. + """ + + def func(x, y): + return x + y + + frozen = dr.freeze(func) + + for i in range(3): + x = t(1, 2, 3) + y = i + + res = frozen(x, y) + ref = func(x, y) + assert dr.all(res == ref) + + assert frozen.n_recordings == 3 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test20_literal(t): + """ + Test that drjit literals, passed to frozen functions do not cause the + function to be re-traced if they change. This is enabled by making the input + opaque. + """ + + @dr.freeze + def func(x, y): + z = x + y + w = t(1) + return z, w + + # Literals + x = t(0, 1, 2) + dr.set_label(x, "x") + y = t(1) + dr.set_label(y, "y") + z, w = func(x, y) + assert dr.all(z == t(1, 2, 3)) + assert dr.all(w == t(1)) + + x = t(0, 1, 2) + y = t(1) + z, w = func(x, y) + assert dr.all(z == t(1, 2, 3)) + assert dr.all(w == t(1)) + + assert func.n_recordings == 1 + + x = t(0, 1, 2) + y = t(2) + z = func(x, y) + assert func.n_recordings == 1 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test21_pointers(t): + """ + Test that it is possible to gather from a same-sized variable. This tests + the kernel size inference algorithm as well as having two kernels in a + frozen function. + """ + UInt32 = dr.uint32_array_t(t) + + @dr.freeze + def func(x): + idx = dr.arange(UInt32, 0, dr.width(x), 3) + + return dr.gather(t, x, idx) + + y = func(t(0, 1, 2, 3, 4, 5, 6)) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test22_gather_memcpy(t): + """ + The gather operation might be elided in favor of a memcpy + if the index is a literal of size 1. + The source of the memcpy is however not known to the recording + mechansim as it might index into the source array. + """ + + def func(x, idx: int): + idx = t(idx) + return dr.gather(t, x, idx) + + frozen = dr.freeze(func) + + for i in range(3): + x = dr.arange(t, i, 3 + i) + dr.make_opaque(x) + ref = func(x, 2) + result = frozen(x, 2) + + assert dr.all(ref == result) + + assert frozen.n_recordings == 1 + + +def get_pkg(t): + with dr.detail.scoped_rtld_deepbind(): + m = pytest.importorskip("call_ext") + backend = dr.backend_v(t) + if backend == dr.JitBackend.LLVM: + return m.llvm + elif backend == dr.JitBackend.CUDA: + return m.cuda + + +@pytest.mark.parametrize("symbolic", [True]) +@pytest.test_arrays("float32, jit, -is_diff, shape=(*)") +def test23_vcall(t, symbolic): + """ + Tests a basic symbolic vcall being called inside a frozen function. + """ + pkg = get_pkg(t) + + A, B, Base, BasePtr = pkg.A, pkg.B, pkg.Base, pkg.BasePtr + Mask = dr.mask_t(t) + a, b = A(), B() + + c = BasePtr(a, a, None, a, a) + + xi = t(1, 2, 8, 3, 4) + yi = t(5, 6, 8, 7, 8) + + @dr.freeze + def func(c, xi, yi): + return c.f(xi, yi) + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, symbolic): + xo, yo = func(c, xi, yi) + + assert dr.all(xo == t(10, 12, 0, 14, 16)) + assert dr.all(yo == t(-1, -2, 0, -3, -4)) + + c = BasePtr(a, a, None, b, b) + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, symbolic): + xo, yo = func(c, xi, yi) + + assert func.n_recordings == 1 + + assert dr.all(xo == t(10, 12, 0, 21, 24)) + assert dr.all(yo == t(-1, -2, 0, 3, 4)) + + +@pytest.mark.parametrize("symbolic", [True]) +@pytest.mark.parametrize("optimize", [True, False]) +@pytest.mark.parametrize("opaque", [True, False]) +@pytest.test_arrays("float32, -is_diff, jit, shape=(*)") +def test24_vcall_optimize(t, symbolic, optimize, opaque): + """ + Test a basic vcall being called inside a frozen function, with the + "OptimizeCalls" flag either being set or not set. As well as opaque and + non-opaque inputs. + """ + pkg = get_pkg(t) + + A, B, Base, BasePtr = pkg.A, pkg.B, pkg.Base, pkg.BasePtr + Mask = dr.mask_t(t) + a, b = B(), B() + + dr.set_label(A.opaque, "A.opaque") + dr.set_label(B.opaque, "B.opaque") + + a.value = t(2) + b.value = t(3) + + if opaque: + dr.make_opaque(a.value, b.value) + + c = BasePtr(a, a, None, a, a) + dr.set_label(c, "c") + + x = t(1, 2, 8, 3, 4) + dr.set_label(x, "x") + + def func(c, xi): + return c.g(xi) + + frozen = dr.freeze(func) + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, symbolic): + with dr.scoped_set_flag(dr.JitFlag.OptimizeCalls, optimize): + xo = frozen(c, x) + + assert dr.all(xo == func(c, x)) + + a.value = t(3) + b.value = t(2) + + if opaque: + dr.make_opaque(a.value, b.value) + + c = BasePtr(a, a, None, b, b) + dr.set_label(c, "c") + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, symbolic): + with dr.scoped_set_flag(dr.JitFlag.OptimizeCalls, optimize): + xo = frozen(c, x) + + assert frozen.n_recordings == 1 + assert dr.all(xo == func(c, x)) + + +@pytest.mark.parametrize("symbolic", [True]) +@pytest.mark.parametrize("optimize", [True, False]) +@pytest.mark.parametrize("opaque", [True, False]) +@pytest.test_arrays("float32, -is_diff, jit, shape=(*)") +def test25_multiple_vcalls(t, symbolic, optimize, opaque): + """ + Test calling multiple vcalls in a frozen function, where the result of the + first is used as the input to the second. + """ + pkg = get_pkg(t) + + A, B, Base, BasePtr = pkg.A, pkg.B, pkg.Base, pkg.BasePtr + Mask = dr.mask_t(t) + a, b = B(), B() + + dr.set_label(A.opaque, "A.opaque") + dr.set_label(B.opaque, "B.opaque") + + a.value = t(2) + b.value = t(3) + + if opaque: + dr.make_opaque(a.value, b.value) + + c = BasePtr(a, a, None, a, a) + dr.set_label(c, "c") + + x = t(1, 2, 8, 3, 4) + dr.set_label(x, "x") + + def func(c, xi): + x = c.h(xi) + dr.make_opaque(x) + return c.g(x) + + frozen = dr.freeze(func) + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, symbolic): + with dr.scoped_set_flag(dr.JitFlag.OptimizeCalls, optimize): + xo = frozen(c, x) + + assert dr.all(xo == func(c, x)) + + a.value = t(3) + b.value = t(2) + + if opaque: + dr.make_opaque(a.value, b.value) + + c = BasePtr(a, a, None, b, b) + dr.set_label(c, "c") + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, symbolic): + with dr.scoped_set_flag(dr.JitFlag.OptimizeCalls, optimize): + xo = frozen(c, x) + + assert frozen.n_recordings == 1 + + assert dr.all(xo == func(c, x)) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test26_freeze(t): + """ + Test freezing a simple frozen function. + """ + UInt32 = dr.uint32_array_t(t) + Float = dr.float32_array_t(t) + + @dr.freeze + def my_kernel(x): + x_int = UInt32(x) + result = x * x + result_int = UInt32(result) + + return result, x_int, result_int + + for i in range(3): + x = Float([1.0, 2.0, 3.0]) + dr.opaque(Float, i) + + y, x_int, y_int = my_kernel(x) + dr.schedule(y, x_int, y_int) + assert dr.allclose(y, dr.square(x)) + assert dr.allclose(x_int, UInt32(x)) + assert dr.allclose(y_int, UInt32(y)) + + +@pytest.mark.parametrize("freeze_first", (True, False)) +@pytest.test_arrays("float32, jit, shape=(*)") +def test27_calling_frozen_from_frozen(t, freeze_first): + """ + Test calling a frozen function from within another frozen function. + The inner frozen function should behave as a normal function. + """ + mod = sys.modules[t.__module__] + Float = mod.Float32 + Array3f = mod.Array3f + n = 37 + x = dr.full(Float, 1.5, n) + dr.opaque(Float, 2) + y = dr.full(Float, 0.5, n) + dr.opaque(Float, 10) + dr.eval(x, y) + + @dr.freeze + def fun1(x): + return dr.square(x) + + @dr.freeze + def fun2(x, y): + return fun1(x) + fun1(y) + + # Calling a frozen function from a frozen function. + if freeze_first: + dr.eval(fun1(x)) + + result1 = fun2(x, y) + assert dr.allclose(result1, dr.square(x) + dr.square(y)) + + if not freeze_first: + # If the nested function hasn't been recorded yet, calling it + # while freezing the outer function shouldn't freeze it with + # those arguments. + # In other words, any freezing mechanism should be completely + # disabled while recording a frozen function. + # assert fun1.frozen.kernels is None + + # We should therefore be able to freeze `fun1` with a different + # type of argument, and both `fun1` and `fun2` should work fine. + result2 = fun1(Array3f(0.5, x, y)) + assert dr.allclose(result2, Array3f(0.5 * 0.5, dr.square(x), dr.square(y))) + + result3 = fun2(2 * x, 0.5 * y) + assert dr.allclose(result3, dr.square(2 * x) + dr.square(0.5 * y)) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test28_recorded_size(t): + """ + Tests that a frozen function, producing a variable with a constant size, + can be replayed and produces an output of the same size. + """ + mod = sys.modules[t.__module__] + UInt32 = mod.UInt32 + Float = mod.Float + + @dr.freeze + def fun(a): + x = t(dr.linspace(Float, -1, 1, 10)) + a + source = x + 2 * x + # source = get_single_entry(x + 2 * x) + index = dr.arange(UInt32, dr.width(source)) + active = index % UInt32(2) != 0 + + return dr.gather(Float, source, index, active) + + a = t(0.1) + res1 = fun(a) + res2 = fun(a) + res3 = fun(a) + + assert dr.allclose(res1, res2) + assert dr.allclose(res1, res3) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test29_with_gathers(t): + """ + Test gathering from an array at every second index in a frozen function. + """ + import numpy as np + + n = 20 + mod = sys.modules[t.__module__] + UInt32 = mod.UInt32 + + rng = np.random.default_rng(seed=1234) + shape = tuple(reversed(dr.shape(dr.zeros(t, n)))) + + def fun(x, idx): + active = idx % 2 != 0 + source = get_single_entry(x) + return dr.gather(type(source), source, idx, active=active) + + fun_frozen = dr.freeze(fun) + + # 1. Recording call + x1 = t(rng.uniform(low=-1, high=1, size=shape)) + idx1 = dr.arange(UInt32, n) + result1 = fun_frozen(x1, idx1) + assert dr.allclose(result1, fun(x1, idx1)) + + # 2. Different source as during recording + x2 = t(rng.uniform(low=-2, high=-1, size=shape)) + idx2 = idx1 + + result2 = fun_frozen(x2, idx2) + assert dr.allclose(result2, fun(x2, idx2)) + + x3 = x2 + idx3 = UInt32([i for i in reversed(range(n))]) + result3 = fun_frozen(x3, idx3) + assert dr.allclose(result3, fun(x3, idx3)) + + # 3. Same source as during recording + result4 = fun_frozen(x1, idx1) + assert dr.allclose(result4, result1) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test30_scatter_with_op(t): + """ + Tests scattering into the input of a frozen function. + """ + import numpy as np + + n = 16 + mod = sys.modules[t.__module__] + UInt32 = mod.UInt32 + + rng = np.random.default_rng(seed=1234) + + def func(x, idx): + active = idx % 2 != 0 + + result = x - 0.5 + dr.scatter(x, result, idx, active=active) + return result + + func_frozen = dr.freeze(func) + + # 1. Recording call + x1 = t(rng.uniform(low=-1, high=1, size=[n])) + x1_copy = t(x1) + x1_copy_copy = t(x1) + idx1 = dr.arange(UInt32, n) + + result1 = func_frozen(x1, idx1) + + # assert dr.allclose(x1, x1_copy) + assert dr.allclose(result1, func(x1_copy, idx1)) + + # 2. Different source as during recording + # TODO: problem: during trace, the actual x1 Python variable changes + # from index r2 to index r12 as a result of the `scatter`. + # But in subsequent launches, even if we successfully create a new + # output buffer equivalent to r12, it doesn't get assigned to `x2`. + x2 = t(rng.uniform(low=-2, high=-1, size=[n])) + x2_copy = t(x2) + idx2 = idx1 + + result2 = func_frozen(x2, idx2) + assert dr.allclose(result2, func(x2_copy, idx2)) + # assert dr.allclose(x2, x2_copy) + + x3 = x2 + x3_copy = t(x3) + idx3 = UInt32([i for i in reversed(range(n))]) + result3 = func_frozen(x3, idx3) + assert dr.allclose(result3, func(x3_copy, idx3)) + # assert dr.allclose(x3, x3_copy) + + # # 3. Same source as during recording + result4 = func_frozen(x1_copy_copy, idx1) + assert dr.allclose(result4, result1) + # # assert dr.allclose(x1_copy_copy, x1) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test31_with_gather_and_scatter(t): + """ + Tests a combination of scatters and gathers in a frozen function. + """ + + import numpy as np + + n = 20 + mod = sys.modules[t.__module__] + UInt32 = mod.UInt32 + + rng = np.random.default_rng(seed=1234) + shape = tuple(reversed(dr.shape(dr.zeros(t, n)))) + + def fun(x, idx): + active = idx % 2 != 0 + dest = get_single_entry(x) + + values = dr.gather(UInt32, idx, idx, active=active) + values = type(dest)(values) + dr.scatter(dest, values, idx, active=active) + return dest, values + + fun_frozen = dr.freeze(fun) + + # 1. Recording call + x1 = t(rng.uniform(low=-1, high=1, size=shape)) + x1_copy = t(x1) + x1_copy_copy = t(x1) + idx1 = dr.arange(UInt32, n) + + result1 = fun_frozen(x1, idx1) + assert dr.allclose(result1, fun(x1_copy, idx1)) + assert dr.allclose(x1, x1_copy) + + # 2. Different source as during recording + x2 = t(rng.uniform(low=-2, high=-1, size=shape)) + x2_copy = t(x2) + idx2 = idx1 + + result2 = fun_frozen(x2, idx2) + assert dr.allclose(result2, fun(x2_copy, idx2)) + assert dr.allclose(x2, x2_copy) + + x3 = x2 + x3_copy = t(x3) + idx3 = UInt32([i for i in reversed(range(n))]) + result3 = fun_frozen(x3, idx3) + assert dr.allclose(result3, fun(x3_copy, idx3)) + assert dr.allclose(x3, x3_copy) + + # 3. Same source as during recording + result4 = fun_frozen(x1_copy_copy, idx1) + assert dr.allclose(result4, result1) + assert dr.allclose(x1_copy_copy, x1) + + +@pytest.mark.parametrize("relative_size", ["<", "=", ">"]) +@pytest.test_arrays("float32, jit, shape=(*)") +def test32_gather_only_pointer_as_input(t, relative_size): + """ + Tests that it is possible to infer the launch size of kernels, if the width + of the resulting variable is a multiple/fraction of the variables from which + the result was gathered. + """ + mod = sys.modules[t.__module__] + Array3f = mod.Array3f + Float = mod.Float32 + UInt32 = mod.UInt32 + + import numpy as np + + rng = np.random.default_rng(seed=1234) + + if relative_size == "<": + + def fun(v): + idx = dr.arange(UInt32, 0, dr.width(v), 3) + return Array3f( + dr.gather(Float, v, idx), + dr.gather(Float, v, idx + 1), + dr.gather(Float, v, idx + 2), + ) + + elif relative_size == "=": + + def fun(v): + idx = dr.arange(UInt32, 0, dr.width(v)) // 2 + return Array3f( + dr.gather(Float, v, idx), + dr.gather(Float, v, idx + 1), + dr.gather(Float, v, idx + 2), + ) + + elif relative_size == ">": + + def fun(v): + max_width = dr.width(v) + idx = dr.arange(UInt32, 0, 5 * max_width) + # TODO(!): what can we do against this literal being baked into the kernel? + active = (idx + 2) < max_width + return Array3f( + dr.gather(Float, v, idx, active=active), + dr.gather(Float, v, idx + 1, active=active), + dr.gather(Float, v, idx + 2, active=active), + ) + + fun_freeze = dr.freeze(fun) + + def check_results(v, result): + size = v.size + if relative_size == "<": + expected = v.T + if relative_size == "=": + idx = np.arange(0, size) // 2 + expected = v.ravel() + expected = np.stack( + [ + expected[idx], + expected[idx + 1], + expected[idx + 2], + ], + axis=0, + ) + elif relative_size == ">": + idx = np.arange(0, 5 * size) + mask = (idx + 2) < size + expected = v.ravel() + expected = np.stack( + [ + np.where(mask, expected[(idx) % size], 0), + np.where(mask, expected[(idx + 1) % size], 0), + np.where(mask, expected[(idx + 2) % size], 0), + ], + axis=0, + ) + + assert np.allclose(result.numpy(), expected) + + # Note: Does not fail for n=1 + n = 7 + + for i in range(3): + v = rng.uniform(size=[n, 3]) + result = fun(Float(v.ravel())) + check_results(v, result) + + for i in range(10): + if i <= 5: + n_lanes = n + else: + n_lanes = n + i + + v = rng.uniform(size=[n_lanes, 3]) + result = fun_freeze(Float(v.ravel())) + + expected_width = { + "<": n_lanes, + "=": n_lanes * 3, + ">": n_lanes * 3 * 5, + }[relative_size] + + # if i == 0: + # assert len(fun_freeze.frozen.kernels) + # for kernel in fun_freeze.frozen.kernels.values(): + # assert kernel.original_input_size == n * 3 + # if relative_size == "<": + # assert kernel.original_launch_size == expected_width + # assert kernel.original_launch_size_ratio == (False, 3, True) + # elif relative_size == "=": + # assert kernel.original_launch_size == expected_width + # assert kernel.original_launch_size_ratio == (False, 1, True) + # else: + # assert kernel.original_launch_size == expected_width + # assert kernel.original_launch_size_ratio == (True, 5, True) + + assert dr.width(result) == expected_width + if relative_size == ">" and n_lanes != n: + pytest.xfail( + reason="The width() of the original input is baked into the kernel to compute the `active` mask during the first launch, so results are incorrect once the width changes." + ) + + check_results(v, result) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test33_multiple_kernels(t): + def fn(x: dr.ArrayBase, y: dr.ArrayBase, flag: bool): + # TODO: test with gathers and scatters, which is a really important use-case. + # TODO: test with launches of different sizes (including the auto-sizing logic) + # TODO: test with an intermediate output of literal type + # TODO: test multiple kernels that scatter_add to a newly allocated kernel in sequence. + + # First kernel uses only `x` + quantity = 0.5 if flag else -0.5 + intermediate1 = x + quantity + intermediate2 = x * quantity + dr.eval(intermediate1, intermediate2) + + # Second kernel uses `x`, `y` and one of the intermediate result + result = intermediate2 + y + + # The function returns some mix of outputs + return intermediate1, None, y, result + + n = 15 + x = dr.full(t, 1.5, n) + dr.opaque(t, 0.2) + y = dr.full(t, 0.5, n) + dr.opaque(t, 0.1) + dr.eval(x, y) + + ref_results = fn(x, y, flag=True) + dr.eval(ref_results) + + fn_frozen = dr.freeze(fn) + for _ in range(2): + results = fn_frozen(x, y, flag=True) + assert dr.allclose(results[0], ref_results[0]) + assert results[1] is None + assert dr.allclose(results[2], y) + assert dr.allclose(results[3], ref_results[3]) + + # TODO: + # We don't yet make a difference between check and no-check + + # for i in range(4): + # new_y = y + float(i) + # # Note: we did not enabled `check` mode, so changing this Python + # # value will not throw an exception. The new value has no influence + # # on the result even though without freezing, it would. + # # TODO: support "signature" detection and create separate frozen + # # function instances. + # new_flag = (i % 2) == 0 + # results = fn_frozen(x, new_y, flag=new_flag) + # assert dr.allclose(results[0], ref_results[0]) + # assert results[1] is None + # assert dr.allclose(results[2], new_y) + # assert dr.allclose(results[3], x * 0.5 + new_y) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test34_global_flag(t): + Float = t + + @dr.freeze + def my_fn(a, b, c=0.5): + return a + b + c + + # Recording + one = Float([1.0] * 9) + result1 = my_fn(one, one, c=0.1) + assert dr.allclose(result1, 2.1) + + # Can change the type of an input + result2 = my_fn(one, one, c=Float(0.6)) + assert dr.allclose(result2, 2.6) + + assert my_fn.n_recordings == 2 + + # Disable frozen kernels globally, now the freezing + # logic should be completely bypassed + with dr.scoped_set_flag(dr.JitFlag.KernelFreezing, False): + result3 = my_fn(one, one, c=0.9) + assert dr.allclose(result3, 2.9) + + +# @pytest.mark.parametrize("struct_style", ["drjit", "dataclass"]) +@pytest.mark.parametrize("struct_style", ["drjit", "dataclass"]) +# @pytest.test_arrays("float32, llvm, jit, -is_diff, shape=(*)") +@pytest.test_arrays("float32, jit, shape=(*)") +def test35_return_types(t, struct_style): + # WARN: only working on CUDA! + mod = sys.modules[t.__module__] + Float = t + Array3f = mod.Array3f + UInt32 = mod.UInt32 + + import numpy as np + + if struct_style == "drjit": + + class ToyDataclass: + DRJIT_STRUCT: dict = {"a": Float, "b": Float} + a: Float + b: Float + + def __init__(self, a=None, b=None): + self.a = a + self.b = b + + else: + assert struct_style == "dataclass" + + @dataclass(frozen=True) + class ToyDataclass: + a: Float + b: Float + + # 1. Many different types + @dr.freeze + def toy1(x: Float) -> Float: + y = x**2 + dr.sin(x) + z = x**2 + dr.cos(x) + return (x, y, z, ToyDataclass(a=x, b=y), {"x": x, "yi": UInt32(y)}, [[[[x]]]]) + + for i in range(2): + input = Float(np.full(17, i)) + # input = dr.full(Float, i, 17) + result = toy1(input) + assert isinstance(result[0], Float) + assert isinstance(result[1], Float) + assert isinstance(result[2], Float) + assert isinstance(result[3], ToyDataclass) + assert isinstance(result[4], dict) + assert result[4].keys() == set(("x", "yi")) + assert isinstance(result[4]["x"], Float) + assert isinstance(result[4]["yi"], UInt32) + assert isinstance(result[5], list) + assert isinstance(result[5][0], list) + assert isinstance(result[5][0][0], list) + assert isinstance(result[5][0][0][0], list) + + # 2. Many different types + @dr.freeze + def toy2(x: Float, target: Float) -> Float: + dr.scatter(target, 0.5 + x, dr.arange(UInt32, dr.width(x))) + return None + + for i in range(3): + input = Float([i] * 17) + target = dr.opaque(Float, 0, dr.width(input)) + # target = dr.full(Float, 0, dr.width(input)) + # target = dr.empty(Float, dr.width(input)) + + result = toy2(input, target) + assert dr.allclose(target, 0.5 + input) + assert result is None + + # 3. DRJIT_STRUCT as input and returning nested dictionaries + @dr.freeze + def toy3(x: Float, y: ToyDataclass) -> Float: + x_d = dr.detach(x, preserve_type=False) + return { + "a": x, + "b": (x, UInt32(2 * y.a + y.b)), + "c": None, + "d": { + "d1": x + x, + "d2": Array3f(x_d, -x_d, 2 * x_d), + "d3": None, + "d4": {}, + "d5": tuple(), + "d6": list(), + "d7": ToyDataclass(a=x, b=2 * x), + }, + "e": [x, {"e1": None}], + } + + for i in range(3): + input = Float([i] * 5) + input2 = ToyDataclass(a=input, b=Float(4.0)) + result = toy3(input, input2) + assert isinstance(result, dict) + assert isinstance(result["a"], Float) + assert isinstance(result["b"], tuple) + assert isinstance(result["b"][0], Float) + assert isinstance(result["b"][1], UInt32) + assert result["c"] is None + assert isinstance(result["d"], dict) + assert isinstance(result["d"]["d1"], Float) + assert isinstance(result["d"]["d2"], Array3f) + assert result["d"]["d3"] is None + assert isinstance(result["d"]["d4"], dict) and len(result["d"]["d4"]) == 0 + assert isinstance(result["d"]["d5"], tuple) and len(result["d"]["d5"]) == 0 + assert isinstance(result["d"]["d6"], list) and len(result["d"]["d6"]) == 0 + assert isinstance(result["d"]["d7"], ToyDataclass) + assert dr.allclose(result["d"]["d7"].a, input) + assert dr.allclose(result["d"]["d7"].b, 2 * input) + assert isinstance(result["e"], list) + assert isinstance(result["e"][0], Float) + assert isinstance(result["e"][1], dict) + assert result["e"][1]["e1"] is None + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test36_drjit_struct_and_matrix(t): + package = sys.modules[t.__module__] + Float = package.Float + Array4f = package.Array4f + Matrix4f = package.Matrix4f + + class MyTransform4f: + DRJIT_STRUCT = { + "matrix": Matrix4f, + "inverse": Matrix4f, + } + + def __init__(self, matrix: Matrix4f = None, inverse: Matrix4f = None): + self.matrix = matrix + self.inverse = inverse + + @dataclass(frozen=False) + class Camera: + to_world: MyTransform4f + + @dataclass(frozen=False) + class Batch: + camera: Camera + value: float = 0.5 + offset: float = 0.5 + + @dataclass(frozen=False) + class Result: + value: Float + constant: int = 5 + + def fun(batch: Batch, x: Array4f): + res1 = batch.camera.to_world.matrix @ x + res2 = batch.camera.to_world.matrix @ x + batch.offset + res3 = batch.value + x + res4 = Result(value=batch.value) + return res1, res2, res3, res4 + + fun_frozen = dr.freeze(fun) + + n = 7 + for i in range(4): + x = Array4f( + *(dr.linspace(Float, 0, 1, n) + dr.opaque(Float, i) + k for k in range(4)) + ) + mat = Matrix4f( + *( + dr.linspace(Float, 0, 1, n) + dr.opaque(Float, i) + ii + jj + for jj in range(4) + for ii in range(4) + ) + ) + trafo = MyTransform4f() + trafo.matrix = mat + trafo.inverse = dr.rcp(mat) + + batch = Batch( + camera=Camera(to_world=trafo), + value=dr.linspace(Float, -1, 0, n) - dr.opaque(Float, i), + ) + # dr.eval(x, trafo, batch.value) + + results = fun_frozen(batch, x) + expected = fun(batch, x) + + assert len(results) == len(expected) + for result_i, (value, expected) in enumerate(zip(results, expected)): + + assert type(value) == type(expected) + if isinstance(value, Result): + value = value.value + expected = expected.value + assert dr.allclose(value, expected), str(result_i) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test37_with_dataclass_in_out(t): + mod = sys.modules[t.__module__] + Int32 = mod.Int32 + UInt32 = mod.UInt32 + Bool = mod.Bool + + @dataclass(frozen=False) + class MyRecord: + step_in_segment: Int32 + total_steps: UInt32 + short_segment: Bool + + def acc_fn(record: MyRecord): + record.step_in_segment += Int32(2) + return Int32(record.total_steps + record.step_in_segment) + + # Initialize MyRecord + n_rays = 100 + record = MyRecord( + step_in_segment=UInt32([1] * n_rays), + total_steps=UInt32([0] * n_rays), + short_segment=dr.zeros(Bool, n_rays), + ) + + # Create frozen kernel that contains another function + frozen_acc_fn = dr.freeze(acc_fn) + + accumulation = dr.zeros(UInt32, n_rays) + n_iter = 12 + for _ in range(n_iter): + accumulation += frozen_acc_fn(record) + + expected = 0 + for i in range(n_iter): + expected += 0 + 2 * (i + 1) + 1 + assert dr.all(accumulation == expected) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test38_allocated_scratch_buffer(t): + """ + Frozen functions may want to allocate some scratch space, scatter to it + in a first kernel, and read / use the values later on. As long as the + size of the scratch space can be guessed (e.g. a multiple of the launch width, + or matching the width of an existing input), we should be able to support it. + + On the other hand, the "scattering to an unknown buffer" pattern may actually + be scattering to an actual pre-existing buffer, which the user simply forgot + to include in the `state` lambda. In order to catch that case, we at least + check that the "scratch buffer" was read from in one of the kernels. + Otherwise, we assume it was meant as a side-effect into a pre-existing buffer. + """ + mod = sys.modules[t.__module__] + # dr.set_flag(dr.JitFlag.KernelFreezing, False) + UInt32 = mod.UInt32 + + # Note: we are going through an object / method, otherwise the closure + # checker would already catch the `forgotten_target_buffer` usage. + class Model: + DRJIT_STRUCT = { + "some_state": UInt32, + # "forgotten_target_buffer": UInt32, + } + + def __init__(self): + self.some_state = UInt32([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + self.forgotten_target_buffer = self.some_state + 1 + dr.eval(self.some_state, self.forgotten_target_buffer) + + @dr.freeze + def fn1(self, x): + # Note: assuming here that the width of `forgotten_target_buffer` doesn't change + index = dr.arange(UInt32, dr.width(x)) % dr.width( + self.forgotten_target_buffer + ) + dr.scatter(self.forgotten_target_buffer, x, index) + + return 2 * x + + @dr.freeze + def fn2(self, x): + # Scratch buffer with width equal to a state variable + scratch = dr.zeros(UInt32, dr.width(self.some_state)) + # Kernel 1: write to `scratch` + index = dr.arange(UInt32, dr.width(x)) % dr.width(self.some_state) + dr.scatter(scratch, x, index) + # Kernel 2: use values from `scratch` directly + result = dr.square(scratch) + # We don't actually return `scratch`, its lifetime is limited to the frozen function. + return result + + @dr.freeze + def fn3(self, x): + # Scratch buffer with width equal to a state variable + scratch = dr.zeros(UInt32, dr.width(self.some_state)) + # Kernel 1: write to `scratch` + index = dr.arange(UInt32, dr.width(x)) % dr.width(self.some_state) + dr.scatter(scratch, x, index) + # Kernel 2: use values from `scratch` via a gather + result = x + dr.gather(UInt32, scratch, index) + # We don't actually return `scratch`, its lifetime is limited to the frozen function. + return result + + model = Model() + + # Suspicious usage, should not allow it to avoid silent surprising behavior + for i in range(4): + x = UInt32(list(range(i + 2))) + assert dr.width(x) < dr.width(model.forgotten_target_buffer) + + if dr.flag(dr.JitFlag.KernelFreezing): + with pytest.raises(RuntimeError): + result = model.fn1(x) + break + + else: + result = model.fn1(x) + assert dr.allclose(result, 2 * x) + + expected = UInt32(model.some_state + 1) + dr.scatter(expected, x, dr.arange(UInt32, dr.width(x))) + assert dr.allclose(model.forgotten_target_buffer, expected) + + # Expected usage, we should allocate the buffer and allow the launch + for i in range(4): + x = UInt32(list(range(i + 2))) # i+1 + assert dr.width(x) < dr.width(model.some_state) + result = model.fn2(x) + expected = dr.zeros(UInt32, dr.width(model.some_state)) + dr.scatter(expected, x, dr.arange(UInt32, dr.width(x))) + assert dr.allclose(result, dr.square(expected)) + + # Expected usage, we should allocate the buffer and allow the launch + for i in range(4): + x = UInt32(list(range(i + 2))) # i+1 + assert dr.width(x) < dr.width(model.some_state) + result = model.fn3(x) + assert dr.allclose(result, 2 * x) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test39_simple_reductions(t): + import numpy as np + + mod = sys.modules[t.__module__] + Float = mod.Float32 + n = 37 + + @dr.freeze + def simple_sum(x): + return dr.sum(x) + + @dr.freeze + def simple_product(x): + return dr.prod(x) + + @dr.freeze + def simple_min(x): + return dr.min(x) + + @dr.freeze + def simple_max(x): + return dr.max(x) + + @dr.freeze + def sum_not_returned_wide(x): + return dr.sum(x) + x + + @dr.freeze + def sum_not_returned_single(x): + return dr.sum(x) + 4 + + def check_expected(fn, expected): + result = fn(x) + + assert dr.width(result) == dr.width(expected) + assert isinstance(result, Float) + assert dr.allclose(result, expected) + + for i in range(3): + x = dr.linspace(Float, 0, 1, n) + dr.opaque(Float, i) + + x_np = x.numpy() + check_expected(simple_sum, np.sum(x_np).item()) + check_expected(simple_product, np.prod(x_np).item()) + check_expected(simple_min, np.min(x_np).item()) + check_expected(simple_max, np.max(x_np).item()) + + check_expected(sum_not_returned_wide, np.sum(x_np).item() + x) + check_expected(sum_not_returned_single, np.sum(x_np).item() + 4) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test40_prefix_reductions(t): + import numpy as np + + mod = sys.modules[t.__module__] + Float = mod.Float32 + n = 37 + + @dr.freeze + def prefix_sum(x): + return dr.prefix_reduce(dr.ReduceOp.Add, x, exclusive=False) + + def check_expected(fn, expected): + result = fn(x) + + assert dr.width(result) == dr.width(expected) + assert isinstance(result, Float) + assert dr.allclose(result, expected) + + for i in range(3): + x = dr.linspace(Float, 0, 1, n) + dr.opaque(Float, i) + + x_np = x.numpy() + check_expected(prefix_sum, Float(np.cumsum(x_np))) + + +@pytest.test_arrays("float32, jit, is_diff, shape=(*)") +def test41_reductions_with_ad(t): + Float = t + n = 37 + + @dr.freeze + def sum_with_ad(x, width_opaque): + intermediate = 2 * x + 1 + dr.enable_grad(intermediate) + + result = dr.square(intermediate) + + # Unfortunately, as long as we don't support creating opaque values + # within a frozen kernel, we can't use `dr.mean()` directly. + loss = dr.sum(result) / width_opaque + # loss = dr.mean(result) + dr.backward(loss) + + return result, intermediate + + @dr.freeze + def product_with_ad(x): + dr.enable_grad(x) + loss = dr.prod(x) + dr.backward_from(loss) + + for i in range(3): + x = dr.linspace(Float, 0, 1, n + i) + dr.opaque(Float, i) + result, intermediate = sum_with_ad(x, dr.opaque(Float, dr.width(x))) + assert dr.width(result) == n + i + + assert dr.grad_enabled(result) + assert dr.grad_enabled(intermediate) + assert not dr.grad_enabled(x) + intermediate_expected = 2 * x + 1 + assert dr.allclose(intermediate, intermediate_expected) + assert dr.allclose(result, dr.square(intermediate_expected)) + assert sum_with_ad.n_recordings == 1 + assert dr.allclose(dr.grad(result), 0) + assert dr.allclose( + dr.grad(intermediate), 2 * intermediate_expected / dr.width(x) + ) + + for i in range(3): + x = dr.linspace(Float, 0.1, 1, n + i) + dr.opaque(Float, i) + result = product_with_ad(x) + + assert result is None + assert dr.grad_enabled(x) + with dr.suspend_grad(): + expected_grad = dr.prod(x) / x + assert dr.allclose(dr.grad(x), expected_grad) + + +# @pytest.test_arrays("float32, jit, shape=(*)") +# def test35_mean(t): +# def func(x): +# return dr.mean(x) +# +# frozen_func = dr.freeze(func) +# +# n = 10 +# +# for i in range(3): +# x = dr.linspace(t, 0, 1, n + i) + dr.opaque(t, i) +# +# result = frozen_func(x) +# assert dr.allclose(result, func(x)) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test42_size_aliasing(t): + def func(x, y): + return x + 1, y + 2 + + frozen_func = dr.freeze(func) + + n = 3 + + for i in range(3): + x = dr.linspace(t, 0, 1, n) + dr.opaque(t, i) + y = dr.linspace(t, 0, 1, n + i) + dr.opaque(t, i) + + result = frozen_func(x, y) + + assert dr.allclose(result, func(x, y)) + + """ + We should have two recordings, one for which the experssions x+1 and y+1 are compiled into the same kernel, + because x and y have the same size and one where they are compiled seperately because their sizes are different. + """ + assert frozen_func.n_recordings == 2 + + +@pytest.test_arrays("float32, jit, -is_diff, shape=(*)") +def test43_pointer_aliasing(t): + """ + Dr.Jit employs a memory cache, which means that two variables + get allocated the same memory region, if one is destroyed + before the other is created. + Since we track variables using their pointers in the `RecordThreadState`, + we have to update the `ptr_to_slot` mapping for new variables. + """ + + n = 4 + + def func(x): + y = x + 1 + dr.make_opaque(y) + for i in range(3): + y = y + 1 + dr.make_opaque(y) + return y + + for i in range(10): + frozen_func = dr.freeze(func) + + x = dr.linspace(t, 0, 1, n) + dr.opaque(t, i) + assert dr.allclose(frozen_func(x), func(x)) + + x = dr.linspace(t, 0, 1, n) + dr.opaque(t, i) + assert dr.allclose(frozen_func(x), func(x)) + + +@pytest.test_arrays("float32, jit, is_diff, shape=(*)") +def test44_simple_ad_fully_inside(t): + mod = sys.modules[t.__module__] + Float = mod.Float + + def my_kernel(x): + dr.enable_grad(x) + + result = x * x + dr.backward(result) + + return result + + for start_enabled in (True, False): + # Re-freeze + my_kernel_frozen = dr.freeze(my_kernel) + + for i in range(3): + x = Float([1.0, 2.0, 3.0]) + dr.opaque(Float, i) + if start_enabled: + dr.enable_grad(x) + + y = my_kernel_frozen(x) + grad_x = dr.grad(x) + grad_y = dr.grad(y) + dr.schedule(y, grad_x, grad_y) + assert dr.allclose(y, dr.square(x)) + assert dr.allclose(grad_y, 0) + assert dr.allclose(grad_x, 2 * x) + + # Status of grad_enabled should be restored (side-effect of the function), + # even if it wasn't enabled at first + assert dr.grad_enabled(x) + + +@pytest.mark.parametrize("set_some_literal_grad", (False,)) +@pytest.mark.parametrize("inputs_end_enabled", (True, False)) +@pytest.mark.parametrize("inputs_start_enabled", (True,)) +@pytest.mark.parametrize("params_end_enabled", (False, False)) +@pytest.mark.parametrize("params_start_enabled", (True,)) +@pytest.mark.parametrize("freeze", (True,)) +@pytest.test_arrays("float32, jit, is_diff, shape=(*)") +def test45_suspend_resume( + t, + params_start_enabled, + params_end_enabled, + inputs_start_enabled, + inputs_end_enabled, + set_some_literal_grad, + freeze, +): + + mod = sys.modules[t.__module__] + Float = mod.Float32 + UInt32 = mod.UInt32 + log_level = dr.log_level() + + # TODO: remove this + # dr.set_flag(dr.JitFlag.KernelFreezing, False) + + class MyModel: + DRJIT_STRUCT = {"params": Float} + + def __init__(self, params): + self.params = params + self.frozen_eval = dr.freeze(type(self).eval) if freeze else type(self).eval + + def eval( + self, + x: Float, + params_end_enabled: bool, + inputs_end_enabled: bool, + set_some_literal_grad: bool, + ): + idx = dr.arange(UInt32, dr.width(x)) % dr.width(self.params) + latents = dr.gather(Float, self.params, idx) + result = x * latents + + with dr.resume_grad(): + dr.set_grad_enabled(self.params, params_end_enabled) + dr.set_grad_enabled(x, inputs_end_enabled) + if set_some_literal_grad: + # If grads are not enabled, this will get ignored, which is fine + dr.set_grad(x, Float(6.66)) + + return result + + model = MyModel(params=Float([1, 2, 3, 4, 5])) + + for i in range(3): + # Inputs of different widths + x = Float([0.1, 0.2, 0.3, 0.4, 0.5, 0.6] * (i + 1)) + dr.opaque(Float, i) + + dr.set_grad_enabled(model.params, params_start_enabled) + dr.set_grad_enabled(x, inputs_start_enabled) + + dr.eval(x, dr.grad(x)) + + with dr.suspend_grad(): + result = model.frozen_eval( + model, x, params_end_enabled, inputs_end_enabled, set_some_literal_grad + ) + + # dr.eval(result, model.params, dr.grad(model.params)) + assert not dr.grad_enabled(result) + assert dr.grad_enabled(model.params) == params_end_enabled + assert dr.grad_enabled(x) == inputs_end_enabled + + # The frozen function should restore the right width, even for a zero-valued literal. + # The default gradients are a zero-valued literal array + # with a width equal to the array's width + grads = dr.grad(model.params) + assert dr.width(grads) == dr.width(model.params) + assert dr.all(grads == 0) + + grads = dr.grad(x) + assert dr.width(grads) == dr.width(x) + if inputs_end_enabled and set_some_literal_grad: + assert dr.all(grads == 6.66) + else: + assert dr.all(grads == 0) + + assert model.frozen_eval.n_recordings == 1 + + +@pytest.test_arrays("float32, jit, is_diff, shape=(*)") +@pytest.mark.parametrize("freeze", (True,)) +@pytest.mark.parametrize("change_params_width", (False,)) +def test46_with_grad_scatter(t, freeze: bool, change_params_width): + mod = sys.modules[t.__module__] + Float = mod.Float32 + UInt32 = mod.UInt32 + log_level = dr.log_level() + + class Model: + DRJIT_STRUCT = {"params": Float} + + def __init__(self, n): + self.params = Float(list(range(1, n + 1))) + assert dr.width(self.params) == n + dr.enable_grad(self.params) + + def __call__(self): + # Cheeky workaround for the frozen kernel signature checking + pass + + def my_kernel(model, x, opaque_params_width): + idx = dr.arange(UInt32, dr.width(x)) % opaque_params_width + + with dr.resume_grad(): + latents = dr.gather(Float, model.params, idx) + contrib = x * latents + dr.backward_from(contrib) + + return dr.detach(contrib) + + model = Model(5) + my_kernel_frozen = dr.freeze(my_kernel) if freeze else my_kernel + + for i in range(6): + # Different width at each iteration + x = Float([1.0, 2.0, 3.0] * (i + 1)) + dr.opaque(Float, i) + + # The frozen kernel should also support the params (and therefore its gradient buffer) + # changing width without issues. + if change_params_width and (i == 3): + model = Model(10) + # Reset gradients + dr.set_grad(model.params, 0) + assert dr.grad_enabled(model.params) + + with dr.suspend_grad(): + y = my_kernel_frozen(model, x, dr.opaque(UInt32, dr.width(model.params))) + assert not dr.grad_enabled(x) + assert not dr.grad_enabled(y) + assert dr.grad_enabled(model.params) + + grad_x = dr.grad(x) + grad_y = dr.grad(y) + grad_p = dr.grad(model.params) + # assert dr.allclose(y, dr.sqr(x)) + + # Expected grads + assert dr.allclose(grad_y, 0) + assert dr.allclose(grad_x, 0) + grad_p_expected = dr.zeros(Float, dr.width(model.params)) + idx = dr.arange(UInt32, dr.width(x)) % dr.width(model.params) + dr.scatter_reduce(dr.ReduceOp.Add, grad_p_expected, x, idx) + assert dr.allclose(grad_p, grad_p_expected) + + +@pytest.test_arrays("float32, jit, is_diff, shape=(*)") +def test47_tutorial_example(t): + mod = sys.modules[t.__module__] + Float = mod.Float32 + UInt32 = mod.UInt32 + + @dr.freeze + def frozen_eval(inputs, idx, params, target_value, grad_factor): + intermediate = dr.gather(Float, params, idx) + result = 0.5 * dr.square(intermediate) * inputs + + # Since reductions are not supported yet, we cannot compute a single + # loss value here. It's not really a problem though, since DrJit can + # backpropagate starting from arrays of any widths. + loss_per_entry = dr.square(result - target_value) * grad_factor + + # The gradients resulting from backpropagation will be directly accumulated + # (via dr.scatter_add()) into the gradient buffer of `params` (= `dr.grad(params)`). + dr.backward_from(loss_per_entry) + + # It's fine to return the primal values of `result`, but keep in mind that they will + # not be differentiable w.r.t. `params`. + return dr.detach(result) + + params = Float([1, 2, 3, 4, 5]) + + for _ in range(3): + dr.disable_grad(params) + dr.enable_grad(params) + assert dr.all(dr.grad(params) == 0) + + inputs = Float([0.1, 0.2, 0.3]) + idx = UInt32([1, 2, 3]) + # Represents the optimizer's loss scale + grad_factor = 4096 / dr.opaque(Float, dr.width(inputs)) + + result = frozen_eval( + inputs, idx, params, target_value=0.5, grad_factor=grad_factor + ) + assert not dr.grad_enabled(result) + # Gradients were correctly accumulated to `params`'s gradients. + assert not dr.all(dr.grad(params) == 0) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test48_compress(t): + + mod = sys.modules[t.__module__] + Float = mod.Float32 + UInt32 = mod.UInt32 + + pkg = get_pkg(t) + Sampler = pkg.Sampler + + def func(sampler: Sampler) -> UInt32: + indices = dr.compress(sampler.next() < 0.5) + return indices + + frozen = dr.freeze(func) + + sampler_func = Sampler(10) + sampler_frozen = Sampler(10) + for i in range(3): + dr.all(frozen(sampler_frozen) == func(sampler_func)) + + sampler_func = Sampler(11) + sampler_frozen = Sampler(11) + for i in range(3): + dr.all(frozen(sampler_frozen) == func(sampler_func)) + + assert frozen.n_recordings == 1 + + +@pytest.test_arrays("uint32, llvm, -is_diff, jit, shape=(*)") +def test49_scatter_reduce_expanded(t): + + def func(target: t, src: t): + dr.scatter_reduce(dr.ReduceOp.Add, target, src, dr.arange(t, dr.width(src)) % 2) + + frozen = dr.freeze(func) + + for i in range(4): + src = dr.full(t, 1, 10 + i) + dr.make_opaque(src) + + result = t([0] * (i + 2)) + dr.make_opaque(result) + frozen(result, src) + + reference = t([0] * (i + 2)) + dr.make_opaque(reference) + func(reference, src) + + assert dr.all(result == reference) + + assert frozen.n_cached_recordings == 1 + assert frozen.n_recordings == 4 + + +@pytest.test_arrays("uint32, llvm, -is_diff, jit, shape=(*)") +def test50_scatter_reduce_expanded_identity(t): + + def func(src: t): + target = dr.zeros(t, 5) + dr.scatter_reduce(dr.ReduceOp.Add, target, src, dr.arange(t, dr.width(src)) % 2) + return target + + frozen = dr.freeze(func) + + for i in range(4): + src = dr.full(t, 1, 10 + i) + dr.make_opaque(src) + + result = frozen(src) + + reference = func(src) + + assert dr.all(result == reference) + + assert frozen.n_recordings == 1 + + +@pytest.test_arrays("uint32, llvm, -is_diff, jit, shape=(*)") +def test51_scatter_reduce_expanded_no_memset(t): + + def func(src: t): + target = dr.full(t, 5) + dr.scatter_reduce(dr.ReduceOp.Add, target, src, dr.arange(t, dr.width(src)) % 2) + return target + + frozen = dr.freeze(func) + + for i in range(4): + src = dr.full(t, 1, 10 + i) + dr.make_opaque(src) + + result = frozen(src) + + reference = func(src) + + assert dr.all(result == reference) + + assert frozen.n_recordings == 1 + assert frozen.n_cached_recordings == 1 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test52_python_inputs(t): + + def func(x: t, neg: bool): + if neg: + return -x + 1 + else: + return x + 1 + + frozen = dr.freeze(func) + + for i in range(3): + for neg in [False, True]: + x = t(1, 2, 3) + dr.opaque(t, i) + + res = frozen(x, neg) + ref = func(x, neg) + assert dr.all(res == ref) + + assert frozen.n_recordings == 2 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test53_scatter_inc(t): + + mod = sys.modules[t.__module__] + Float = mod.Float32 + UInt32 = mod.UInt32 + + n = 37 + + def acc_with_scatter_inc(x, counter, out_buffer, max_points): + active = x > 0.5 + out_idx = dr.scatter_inc(counter, UInt32(0), active=active) + active &= out_idx < max_points + + # TODO: also test within a loop + dr.scatter(out_buffer, x, out_idx, active=active) + + def test(i, func): + x = dr.linspace(Float, 0.1, 1, n + i) + dr.opaque(Float, i) / 100 + counter = UInt32(dr.opaque(UInt32, 0)) + out_buffer = dr.zeros(Float, 10) + max_points = UInt32(dr.opaque(UInt32, dr.width(out_buffer))) + + dr.set_label(x, "x") + dr.set_label(counter, "counter") + dr.set_label(out_buffer, "out_buffer") + dr.set_label(max_points, "max_points") + dr.eval(x, counter, out_buffer, max_points) + + func(x, counter, out_buffer, max_points) + + return out_buffer, counter + + def func(i): + return test(i, acc_with_scatter_inc) + + acc_with_scatter_inc = dr.freeze(acc_with_scatter_inc) + + def frozen(i): + return test(i, acc_with_scatter_inc) + + for i in range(3): + + res, _ = frozen(i) + ref, _ = func(i) + + # assert dr.all(res == ref) + + # Should have filled all of the entries + assert dr.all(res > 0.5) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test54_read_while_frozen(t): + # dr.set_flag(dr.JitFlag.KernelFreezing, True) + assert dr.flag(dr.JitFlag.KernelFreezing) + + def func(x): + return x[1] + + frozen = dr.freeze(func) + + x = t(1, 2, 3) + with pytest.raises(RuntimeError): + frozen(x) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +def test55_var_upload(t): + def func(x): + + arrays = [] + + for i in range(3): + y = dr.arange(t, 3) + dr.make_opaque(y) + arrays.append(y) + + del arrays + del y + + return x / t(10, 10, 10) + + frozen = dr.freeze(func) + + for i in range(3): + x = dr.arange(t, 3) + dr.make_opaque(x) + + # with pytest.raises(RuntimeError, match = "created while recording"): + with pytest.raises(RuntimeError): + z = frozen(x) + + # assert dr.allclose(z, func(x)) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +def test56_grad_isolate(t): + dr.set_flag(dr.JitFlag.ReuseIndices, False) + + def f(x): + return x * 2 + + def g(y): + return y * 3 + + def func(y): + z = g(y) + dr.backward(z) + + frozen = dr.freeze(func) + + for i in range(3): + + x = dr.arange(t, 3) + dr.make_opaque(x) + dr.enable_grad(x) + + y = f(x) + with dr.isolate_grad(): + func(y) + + ref = dr.grad(x) + + x = dr.arange(t, 3) + dr.make_opaque(x) + dr.enable_grad(x) + + y = f(x) + dr.make_opaque(y) + frozen(y) + + res = dr.grad(x) + + assert dr.allclose(ref, res) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +def test57_isolate_grad_fwd(t): + + def f(x): + return x * x + + def g(y): + return y * 2 + + def func(x): + with dr.isolate_grad(): + y = f(x) + dr.forward(x) + return y + + def frozen(x): + y = f(x) + dr.forward(x) + return y + + frozen = dr.freeze(frozen) + + for i in range(3): + x = t(i) + dr.make_opaque(x) + dr.enable_grad(x) + + y = func(x) + # z = g(y) + + ref = dr.grad(y) + + x = t(i) + dr.make_opaque(x) + dr.enable_grad(x) + + y = frozen(x) + # z = g(y) + + res = dr.grad(y) + + assert dr.allclose(ref, res) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +def test58_grad_postponed_part(t): + dr.set_flag(dr.JitFlag.ReuseIndices, False) + + def f(x): + return x * x * 2 + + def g(y): + return y * y * 3 + + def func(y1, y2): + z1 = g(y1) + z2 = g(y2) + dr.backward(z1) + + frozen = dr.freeze(func) + + def run(i, name, func): + x1 = dr.arange(t, 3) + i + dr.make_opaque(x1) + dr.enable_grad(x1) + y1 = f(x1) + + x2 = dr.arange(t, 3) + i + dr.make_opaque(x2) + dr.enable_grad(x2) + dr.set_grad(x2, 2) + y2 = f(x2) + + func(y1, y2) + + dx1 = dr.grad(x1) + dx2_1 = dr.grad(x2) + + dr.set_grad(x2, 1) + dr.backward(x2) + dx1_2 = dr.grad(x2) + + return [dx1, dx1_2, dx2_1] + + for i in range(3): + + def isolated(y1, y2): + with dr.isolate_grad(): + func(y1, y2) + + ref = run(i, "reference", isolated) + res = run(i, "frozen", frozen) + + for ref, res in zip(ref, res): + assert dr.allclose(ref, res) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +def test59_nested(t): + + pkg = get_pkg(t) + mod = sys.modules[t.__module__] + + A, B, Base, BasePtr = pkg.A, pkg.B, pkg.Base, pkg.BasePtr + a, b = A(), B() + a.value = dr.ones(t, 16) + dr.enable_grad(a.value) + + U = mod.UInt32 + xi = t(1, 2, 8, 3, 4) + yi = dr.reinterpret_array(U, BasePtr(a, a, a, a, a)) + + def nested(self, xi, yi): + return self.nested(xi, yi) + + def func(c, xi, yi): + return dr.dispatch(c, nested, xi, yi) + + frozen = dr.freeze(func) + + for i in range(3): + c = BasePtr(a, a, a, b, b) + xi = t(1, 2, 8, 3, 4) + yi = dr.reinterpret_array(U, BasePtr(a, a, a, a, a)) + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, True): + xref = func(c, xi, yi) + + assert dr.all(xref == xi + 1) + + c = BasePtr(a, a, a, b, b) + xi = t(1, 2, 8, 3, 4) + yi = dr.reinterpret_array(U, BasePtr(a, a, a, a, a)) + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, True): + xfrozen = frozen(c, xi, yi) + + assert dr.all(xfrozen == xref) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +def test60_call_raise(t): + + mod = sys.modules[t.__module__] + pkg = get_pkg(t) + + UInt = mod.UInt + + A, B, Base, BasePtr = pkg.A, pkg.B, pkg.Base, pkg.BasePtr + a, b = A(), B() + + def f(x: t): + raise RuntimeError("test") + + def g(self, x: t): + if isinstance(self, B): + raise RuntimeError + return x + 1 + + c = BasePtr(a, a, a, b, b) + + with pytest.raises(RuntimeError): + dr.dispatch(c, g, t(1, 1, 2, 2, 2)) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +def test61_reduce_dot(t): + def func(x, y): + return dr.dot(x, y) + + frozen = dr.freeze(func) + + for i in range(3): + x = dr.arange(t, 10 + i) + y = dr.arange(t, 10 + i) + + result = frozen(x, y) + reference = func(x, y) + + assert dr.allclose(result, reference) + + assert frozen.n_recordings == 1 + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test62_clear(t): + @dr.freeze + def func(x): + return x + 1 + + x = dr.arange(t, 10) + y = func(x) + assert func.n_recordings == 1 + + func.clear() + assert func.n_recordings == 0 + + x = dr.arange(t, 10) + y = func(x) + assert func.n_recordings == 1 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test63_method_decorator(t): + mod = sys.modules[t.__module__] + + class Custom: + DRJIT_STRUCT = {"state": t} + + def __init__(self) -> None: + self.state = t([1, 2, 3]) + + @dr.freeze + def frozen(self, x): + return x + self.state + + def func(self, x): + return x + self.state + + c = Custom() + for i in range(3): + x = dr.arange(t, 3) + i + dr.make_opaque(x) + res = c.frozen(x) + ref = c.func(x) + + assert dr.all(res == ref) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test64_tensor(t): + """ + Tests that constructing tensors in frozen functions is possible, and does + not cause leaks. + """ + mod = sys.modules[t.__module__] + Float32 = mod.Float32 + TensorXf = mod.TensorXf + + def func(x): + return TensorXf(x + 1) + + frozen = dr.freeze(func) + + for i in range(3): + x = dr.arange(Float32, 100) + ref = func(x) + res = frozen(x) + assert dr.all(res == ref) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test65_assign_tensor(t): + """ + Tests that assigning tensors to the input of frozen functions is possible, + and does not cause leaks. + """ + mod = sys.modules[t.__module__] + Float32 = mod.Float32 + TensorXf = mod.TensorXf + + def func(x): + x += 1 + + frozen = dr.freeze(func) + + for i in range(3): + x = TensorXf(dr.arange(Float32, 100)) + func(x) + ref = x + + x = TensorXf(dr.arange(Float32, 100)) + frozen(x) + res = x + assert dr.all(res == ref) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test66_closure(t): + + c1 = 1 + c2 = t(2) + + def func(x): + return x + c1 + c2 + + frozen = dr.freeze(func) + + for i in range(3): + x = dr.arange(t, i + 2) + ref = func(x) + + x = dr.arange(t, i + 2) + res = frozen(x) + + assert dr.allclose(ref, res) + + assert frozen.n_recordings == 1 + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test67_mutable_closure(t): + """ + Test that it is possible to use and modify closures in frozen functions. + """ + y1 = t(1, 2, 3) + y2 = t(1, 2, 3) + + def func(x): + nonlocal y1 + y1 += x + + @dr.freeze + def frozen(x): + nonlocal y2 + y2 += x + + for i in range(3): + x = t(i) + dr.make_opaque(x) + + func(x) + frozen(x) + + assert dr.allclose(y1, y2) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test68_state_decorator(t): + mod = sys.modules[t.__module__] + Float = mod.Float32 + UInt32 = mod.UInt32 + + # Note: not a dataclass or DRJIT_STRUCT + class MyClass: + def __init__(self): + self.something1 = 4.5 + self.something2 = Float([1, 2, 3, 4, 5]) + + @dr.freeze(state_fn=lambda self, *_, **__: (self.something2)) + def frozen(self, x: Float, idx: UInt32) -> Float: + return x * self.something1 * dr.gather(Float, self.something2, idx) + + def func(self, x: Float, idx: UInt32) -> Float: + return x * self.something1 * dr.gather(Float, self.something2, idx) + + c = MyClass() + + for i in range(3): + x = dr.arange(Float, i + 2) + idx = dr.arange(UInt32, i + 2) + + res = c.frozen(x, idx) + ref = c.func(x, idx) + + assert dr.allclose(res, ref) + assert c.frozen.n_recordings == 1 + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("limit", (-1, None, 0, 1, 2)) +def test69_max_cache_size(t, limit): + """ + Tests different cache size limitations for the frozen function. + """ + + def func(x, p): + return x + p + + frozen = dr.freeze(func, limit=limit) + + n = 3 + for i in range(n): + + x = t(i, i+1, i+2) + + res = frozen(x, i) + ref = func(x, i) + + assert dr.allclose(res, ref) + + if limit == -1 or limit is None: + assert frozen.n_recordings == n + assert frozen.n_cached_recordings == n + elif limit == 0: + assert frozen.n_recordings == 0 + assert frozen.n_cached_recordings == 0 + else: + assert frozen.n_recordings == n + assert frozen.n_cached_recordings == limit + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test69_lru_eviction(t): + """ + Tests that the least recently used cache entry is evicted from the frozen + function if the cache size is limited. + """ + + def func(x, p): + return x + p + + frozen = dr.freeze(func, limit=2) + + x = t(0, 1, 2) + + # Create two entries in the cache + frozen(x, 0) + frozen(x, 1) + + # This should evict the first one + frozen(x, 2) + + assert frozen.n_recordings == 3 + + # p = 1 should still be in the cache, and calling it should not increment + # the recording counter. + frozen(x, 1) + + assert frozen.n_recordings == 3 + + # p = 0 should be evicted, and calling it will increment the counter + frozen(x, 0) + + assert frozen.n_recordings == 4 + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test70_warn_recordings(t): + """ + This test simply calls the frozen function with incompattible inputs, and + should print two warnings. + """ + + def func(x, i): + return x + i + + frozen = dr.freeze(func, warn_after=2) + + for i in range(4): + x = t(1, 2, 3) + frozen(x, i) + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("force_optix", [True, False]) +def test71_texture(t, force_optix): + mod = sys.modules[t.__module__] + Texture1f = mod.Texture1f + Float = mod.Float32 + + def func(tex: Texture1f, pos: Float): + return tex.eval(pos) + + frozen = dr.freeze(func) + + with dr.scoped_set_flag(dr.JitFlag.ForceOptiX, force_optix): + n = 4 + for i in range(3): + tex = Texture1f([2], 1, True, dr.FilterMode.Linear, dr.WrapMode.Repeat) + tex.set_value(t(0, 1)) + + pos = dr.arange(Float, i+2) / n + + res = frozen(tex, pos) + ref = func(tex, pos) + + assert dr.allclose(res, ref) + + assert frozen.n_recordings < n + +@pytest.test_arrays("float32, jit, shape=(*)") +def test72_no_input(t): + mod = sys.modules[t.__module__] + + backend = dr.backend_v(t) + + def func(): + return dr.arange(t, 10) + + frozen = dr.freeze(func, backend = backend) + + for i in range(3): + res = frozen() + ref = func() + + assert dr.allclose(res, ref) + + wrong_backend = ( + dr.JitBackend.CUDA if backend == dr.JitBackend.LLVM else dr.JitBackend.LLVM + ) + + frozen = dr.freeze(func, backend=wrong_backend) + + with pytest.raises(RuntimeError): + for i in range(3): + res = frozen() + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test76_changing_literal_width_holder(t): + + class MyHolder: + DRJIT_STRUCT = {"lit": t} + def __init__(self, lit): + self.lit = lit + + def func(x: t, lit: MyHolder): + return x + 1 + + # Note: only fails with auto_opaque=True + frozen = dr.freeze(func, warn_after=3) + + n = 10 + for i in range(n): + holder = MyHolder(dr.zeros(dr.tensor_t(t), (i+1) * 10)) + x = holder.lit + 0.5 + dr.make_opaque(x) + + res = frozen(x, holder) + ref = func(x, holder) + + assert dr.allclose(ref, res) + + assert frozen.n_recordings == 1 diff --git a/tests/while_loop_ext.cpp b/tests/while_loop_ext.cpp index 64613e7f3..e207a38a7 100644 --- a/tests/while_loop_ext.cpp +++ b/tests/while_loop_ext.cpp @@ -3,6 +3,7 @@ #include #include #include +#include namespace nb = nanobind; namespace dr = drjit; @@ -37,11 +38,13 @@ struct Sampler { T next() { return rng.next_float32(); } - void traverse_1_cb_ro(void *payload, void (*fn)(void *, uint64_t)) const { + void traverse_1_cb_ro(void *payload, + dr::detail::traverse_callback_ro fn) const { traverse_1_fn_ro(rng, payload, fn); } - void traverse_1_cb_rw(void *payload, uint64_t (*fn)(void *, uint64_t)) { + void traverse_1_cb_rw(void *payload, + dr::detail::traverse_callback_rw fn) { traverse_1_fn_rw(rng, payload, fn); } From 36a8dd3bc7ddd5fed19e006a9c4369581808bf44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20D=C3=B6ring?= Date: Fri, 18 Apr 2025 12:25:20 +0200 Subject: [PATCH 2/3] Fixed freezing drjit optimizers --- drjit/opt.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/drjit/opt.py b/drjit/opt.py index 67b174cf8..39a0bd0d6 100644 --- a/drjit/opt.py +++ b/drjit/opt.py @@ -126,6 +126,11 @@ class Optimizer(Generic[Extra], MutableMapping[str, dr.ArrayBase]): # - an arbitrary sequence of additional optimizer-dependent state values state: Dict[str, Tuple[dr.ArrayBase, Optional[LearningRate], Extra]] + DRJIT_STRUCT = { + "lr": LearningRate, + "state": dict, + } + def __init__( self, lr: LearningRate, @@ -960,10 +965,15 @@ def _step( # Compute the step size scale, which is a product of # - EMA debiasing factor # - Adaptive/parameter-specific scaling + Float32 = dr.float32_array_t(dr.leaf_t(grad)) + Float64 = dr.float64_array_t(dr.leaf_t(grad)) + ema_factor = Float32( + -dr.sqrt(1 - Float64(self.beta_2) ** t) / (1 - Float64(self.beta_1) ** t) + ) scale = cache.product( dr.leaf_t(grad), # Desired type lr, - -dr.sqrt(1 - self.beta_2**t) / (1 - self.beta_1**t), + ema_factor, ) # Optional: use maximum of second order term @@ -981,9 +991,11 @@ def _step( def _reset(self, key: str, value: dr.ArrayBase, /) -> None: valarr = value.array tp = type(valarr) + UInt = dr.uint32_array_t(dr.leaf_t(tp)) + t = UInt(0) m_t = dr.opaque(tp, 0, valarr.shape) v_t = dr.opaque(tp, 0, valarr.shape) - self.state[key] = value, None, (0, m_t, v_t) + self.state[key] = value, None, (t, m_t, v_t) # Blend between the old and new versions of the optimizer extra state def _select( From 843f66872ec7777ad28c7dcb0b233b1c92675bd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20D=C3=B6ring?= Date: Fri, 18 Apr 2025 14:06:48 +0200 Subject: [PATCH 3/3] Added optimizer freezing tests --- tests/test_freeze.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/test_freeze.py b/tests/test_freeze.py index fc34a3d40..c97833095 100644 --- a/tests/test_freeze.py +++ b/tests/test_freeze.py @@ -2822,3 +2822,44 @@ def func(x: t, lit: MyHolder): assert dr.allclose(ref, res) assert frozen.n_recordings == 1 + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("optimizer", ["sdg", "rmsprop", "adam"]) +def test77_optimizers(t, optimizer): + n = 10 + + def func(y, opt): + loss = dr.mean(dr.square(opt["x"] - y)) + + dr.backward(loss) + + opt.step() + + return opt["x"], loss + + def init_optimizer(): + if optimizer == "sdg": + opt = dr.opt.SGD(lr = 0.001, momentum = 0.9) + elif optimizer == "rmsprop": + opt = dr.opt.RMSProp(lr = 0.001) + elif optimizer == "adam": + opt = dr.opt.Adam(lr = 0.001) + return opt + + frozen = dr.freeze(func) + + opt_func = init_optimizer() + opt_frozen = init_optimizer() + + for i in range(n): + x = dr.full(t, 1, 10) + y = dr.full(t, 0, 10) + + opt_func["x"] = x + opt_frozen["x"] = x + + res_x, res_loss = frozen(y, opt_frozen) + ref_x, ref_loss = func(y, opt_func) + + assert dr.allclose(res_x, ref_x) + assert dr.allclose(res_loss, ref_loss)