diff --git a/drjit/__init__.py b/drjit/__init__.py
index 81c907d4..18d60d06 100644
--- a/drjit/__init__.py
+++ b/drjit/__init__.py
@@ -19,12 +19,12 @@ import sys as _sys
 
 if _sys.version_info < (3, 11):
     try:
-        from typing_extensions import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable
+        from typing_extensions import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable, TypeVar
     except ImportError:
         raise RuntimeError(
            "Dr.Jit requires the 'typing_extensions' package on Python <3.11")
 else:
-    from typing import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable
+    from typing import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable, TypeVar
 
 from .ast import syntax, hint
 from .interop import wrap
@@ -2494,6 +2494,256 @@ def binary_search(start, end, pred):
     return start
 
 
+# Represents the frozen function passed to the decorator without arguments
+F = TypeVar("F")
+# Represents the frozen function passed to the decorator with arguments
+F2 = TypeVar("F2")
+
+@overload
+def freeze(
+    f: None = None,
+    *,
+    state_fn: Optional[Callable],
+    limit: Optional[int] = None,
+    warn_after: int = 10,
+    backend: Optional[JitBackend] = None,
+) -> Callable[[F], F]:
+    """
+    Decorator to "freeze" functions, which improves efficiency by removing
+    repeated JIT tracing overheads.
+
+    In general, Dr.Jit traces computation and then compiles and launches kernels
+    containing this trace (see the section on :ref:`evaluation ` for
+    details). While the compilation step can often be skipped via caching, the
+    tracing cost can still be significant, especially when repeatedly evaluating
+    complex models, e.g., as part of an optimization loop.
+
+    The :py:func:`@dr.freeze ` decorator addresses this problem by
+    altogether removing the need to trace repeatedly. For example, consider the
+    following decorated function:
+
+    .. code-block:: python
+
+       @dr.freeze
+       def f(x, y, z):
+           return ... # Complicated code involving the arguments
+
+    Dr.Jit will trace the first call to the decorated function ``f()``, while
+    collecting additional information regarding the nature of the function's
+    inputs and regarding the CPU/GPU kernel launches representing the body of
+    ``f()``.
+
+    If the function is subsequently called with *compatible* arguments (more on
+    this below), it will immediately launch the previously recorded CPU/GPU
+    kernels without re-tracing, which can substantially improve performance.
+
+    When :py:func:`@dr.freeze ` detects *incompatibilities* (e.g., ``x``
+    having a different type compared to the previous call), it will conservatively
+    re-trace the body and keep track of another potential input configuration.
+
+    Frozen functions support arbitrary :ref:`PyTrees ` as function
+    arguments and return values.
+
+    The following may trigger re-tracing:
+
+    - Changes in the **type** of an argument or :ref:`PyTree ` element.
+    - Changes in the **length** of a container (``list``, ``tuple``, ``dict``).
+    - Changes of **dictionary keys** or **field names** of dataclasses.
+    - Changes in the AD status (:py:`dr.grad_enabled() `) of a variable.
+    - Changes of (non-PyTree) **Python objects**, as detected by mismatching ``hash()``.
+
+    The following more technical conditions also trigger re-tracing:
+
+    - A Dr.Jit variable changes from/to a **scalar** configuration (size ``1``).
+    - The sets of variables of the same size change.
+      In the example above, this would be the case if ``len(x) == len(y)`` in
+      one call, and ``len(x) != len(y)`` subsequently.
+    - When Dr.Jit variables reference external memory (e.g. mapped NumPy arrays), the
+      memory can be aligned or unaligned. A re-tracing step is needed when this
+      status changes.
+
+    These all correspond to situations where the generated kernel code may need to
+    change, and the system conservatively re-traces to ensure correctness.
+
+    Frozen functions support arguments with a different variable *width* (see
+    :py:func:`dr.width() `) without re-tracing, as long as the sets of
+    variables of the same width stay consistent.
+
+    Some constructions are problematic and should be avoided in frozen functions:
+
+    - The function :py:func:`dr.width() ` returns an integer literal
+      that may be merged into the generated code. If the frozen function is later
+      rerun with differently-sized arguments, the executed kernels will still
+      reference the old size. One exception to this rule is constructions like
+      ``dr.arange(UInt32, dr.width(a))``, where the result only implicitly depends
+      on the width value.
+
+    **Advanced features**. The :py:func:`@dr.freeze ` decorator takes
+    several optional parameters that are helpful in certain situations.
+
+    - **Warning when re-tracing happens too often**: Incompatible arguments trigger
+      re-tracing, which can mask issues where *accidentally* incompatible arguments
+      keep :py:func:`@dr.freeze ` from producing the expected
+      performance benefits.
+
+      In such situations, it can be helpful to warn and identify changing
+      parameters by name. This feature is enabled and set to ``10`` by default.
+
+      .. code-block:: pycon
+
+         >>> @dr.freeze(warn_after=1)
+         ... def f(x):
+         ...     return x
+         ...
+         >>> f(Int(1))
+         >>> f(Float(1))
+         The frozen function has been recorded 2 times, this indicates a problem
+         with how the frozen function is being called. For example, calling it
+         with changing python values such as an index. For more information about
+         which variables changed set the log level to ``LogLevel::Debug``.
+
+    - **Limiting memory usage**. Storing kernels for many possible input
+      configurations requires device memory, which can become problematic. Set the
+      ``limit=`` parameter to enable an LRU cache. This is useful when calls to a
+      function are mostly compatible but require occasional re-tracing.
+
+    Args:
+        limit (Optional[int]): An optional integer specifying the maximum number of
+          stored configurations. Once this limit is reached, incompatible calls
+          requiring re-tracing will cause the least recently used configuration
+          to be dropped.
+
+        warn_after (int): When the number of re-tracing steps exceeds this value,
+          Dr.Jit will generate a warning that explains which variables changed
+          between calls to the function.
+
+        state_fn (Optional[Callable]): This optional callable can specify additional
+          state that identifies the configuration. ``state_fn`` will be called with
+          the same arguments as the decorated function. It should return a
+          traversable object (e.g., a list or tuple) that is conceptually treated
+          as if it was another input of the function.
+
+        backend (Optional[JitBackend]): If no inputs are given when calling the
+          frozen function, the backend used has to be specified using this argument.
+          It must match the backend used for computation within the function.
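As a hypothetical usage sketch (the ``drjit.llvm`` array types and the ``Renderer`` class are illustrative assumptions, not part of the patch), the parameters documented above might be combined as follows:

.. code-block:: python

   import drjit as dr
   from drjit.llvm import Float, UInt32

   @dr.freeze(warn_after=2, limit=4)
   def scale_and_offset(x: Float, offset: Float) -> Float:
       # Traced and recorded on the first call, replayed afterwards
       return dr.fma(x, 2.0, offset)

   x = dr.arange(Float, 1024)
   y = scale_and_offset(x, Float(1.0))      # first call: traced and recorded
   z = scale_and_offset(x + 1, Float(2.0))  # compatible arguments: replayed
   assert scale_and_offset.n_recordings == 1

   class Renderer:
       def __init__(self, spp: int):
           self.spp = spp  # plain Python value that influences the result

       # ``state_fn`` receives the same arguments as the method (including
       # ``self``) and exposes ``spp`` as additional identifying state
       @dr.freeze(state_fn=lambda self, *args, **kwargs: self.spp)
       def render(self, seed: UInt32) -> Float:
           return Float(seed) * self.spp

The method example relies on the ``FrozenMethod`` mechanism defined further below, which forwards ``self`` to both the frozen function and ``state_fn``.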
+ """ + + +@overload +def freeze( + f: F, + *, + state_fn: Optional[Callable] = None, + limit: Optional[int] = None, + warn_after: int = 10, + backend: Optional[JitBackend] = None, +) -> F: ... + + +def freeze( + f: Optional[F] = None, + *, + state_fn: Optional[Callable] = None, + limit: Optional[int] = None, + warn_after: int = 10, + backend: Optional[JitBackend] = None, +) -> Union[F, Callable[[F2], F2]]: + limit = limit if limit is not None else -1 + backend = backend if backend is not None else JitBackend.Invalid + + def decorator(f): + """ + Internal decorator, returned in ``dr.freeze`` was used with arguments. + """ + import functools + import inspect + + def inner(closure, *args, **kwargs): + """ + This inner function is the one that gets actually frozen. It receives + any additional state such as closures or state specified with the + ``state`` lambda, and allows for traversal of it. + """ + return f(*args, **kwargs) + + class FrozenFunction: + def __init__(self, f) -> None: + closure = inspect.getclosurevars(f) + self.closure = (closure.nonlocals, closure.globals) + self.frozen = detail.FrozenFunction( + inner, + limit, + warn_after, + backend, + ) + + def __call__(self, *args, **kwargs): + _state = state_fn(*args, **kwargs) if state_fn is not None else None + return self.frozen([self.closure, _state], *args, **kwargs) + + @property + def n_recordings(self): + """ + Represents the number of times the function was recorded. This + includes occasions where it was recorded due to a dry-run failing. + It does not necessarily correspond to the number of recordings + currently cached see ``n_cached_recordings`` for that. + """ + return self.frozen.n_recordings + + @property + def n_cached_recordings(self): + """ + Represents the number of recordings currently cached of the frozen + function. If a recording fails in dry-run mode, it will not create + a new recording, but replace the recording that was attemted to be + replayed. The number of recordings can also be limited with + the ``max_cache_size`` argument. + """ + return self.frozen.n_cached_recordings + + def clear(self): + """ + Clears the recordings of the frozen function, and resets the + ``n_recordings`` counter. The reference to the function is still + kept, and the frozen function can be called again to re-trace + new recordings. + """ + return self.frozen.clear() + + def __get__(self, obj, type=None): + if obj is None: + return self + else: + return FrozenMethod(self.frozen, self.closure, obj) + + class FrozenMethod(FrozenFunction): + """ + A FrozenMethod currying the object into the __call__ method. + + If the ``freeze`` decorator is applied to a method of some class, it has + to call the internal frozen function with the ``self`` argument. To this + end we implement the ``__get__`` method of the frozen function, to + return a ``FrozenMethod``, which holds a reference to the object. + The ``__call__`` method of the ``FrozenMethod`` then supplies the object + in addition to the arguments to the internal function. 
+ """ + def __init__(self, frozen, closure, obj) -> None: + self.obj = obj + self.frozen = frozen + self.closure = closure + + def __call__(self, *args, **kwargs): + _state = state_fn(self.obj, *args, **kwargs) if state_fn is not None else None + return self.frozen([self.closure, _state], self.obj, *args, **kwargs) + + return functools.wraps(f)(FrozenFunction(f)) + + if f is not None: + return decorator(f) + else: + return decorator + + +del F +del F2 def assert_true( cond, diff --git a/drjit/opt.py b/drjit/opt.py index 67b174cf..39a0bd0d 100644 --- a/drjit/opt.py +++ b/drjit/opt.py @@ -126,6 +126,11 @@ class Optimizer(Generic[Extra], MutableMapping[str, dr.ArrayBase]): # - an arbitrary sequence of additional optimizer-dependent state values state: Dict[str, Tuple[dr.ArrayBase, Optional[LearningRate], Extra]] + DRJIT_STRUCT = { + "lr": LearningRate, + "state": dict, + } + def __init__( self, lr: LearningRate, @@ -960,10 +965,15 @@ def _step( # Compute the step size scale, which is a product of # - EMA debiasing factor # - Adaptive/parameter-specific scaling + Float32 = dr.float32_array_t(dr.leaf_t(grad)) + Float64 = dr.float64_array_t(dr.leaf_t(grad)) + ema_factor = Float32( + -dr.sqrt(1 - Float64(self.beta_2) ** t) / (1 - Float64(self.beta_1) ** t) + ) scale = cache.product( dr.leaf_t(grad), # Desired type lr, - -dr.sqrt(1 - self.beta_2**t) / (1 - self.beta_1**t), + ema_factor, ) # Optional: use maximum of second order term @@ -981,9 +991,11 @@ def _step( def _reset(self, key: str, value: dr.ArrayBase, /) -> None: valarr = value.array tp = type(valarr) + UInt = dr.uint32_array_t(dr.leaf_t(tp)) + t = UInt(0) m_t = dr.opaque(tp, 0, valarr.shape) v_t = dr.opaque(tp, 0, valarr.shape) - self.state[key] = value, None, (0, m_t, v_t) + self.state[key] = value, None, (t, m_t, v_t) # Blend between the old and new versions of the optimizer extra state def _select( diff --git a/ext/drjit-core b/ext/drjit-core index e63e186e..4b7c51d0 160000 --- a/ext/drjit-core +++ b/ext/drjit-core @@ -1 +1 @@ -Subproject commit e63e186eefc2647e00efc2e1aeae3c66d996714e +Subproject commit 4b7c51d0e9eb3fa622e1f65f6ef1a79bae51b96e diff --git a/include/drjit/array_traverse.h b/include/drjit/array_traverse.h index 8fee8681..6ae82a5f 100644 --- a/include/drjit/array_traverse.h +++ b/include/drjit/array_traverse.h @@ -15,6 +15,8 @@ #pragma once +#include + #define DRJIT_STRUCT_NODEF(Name, ...) 
\ Name(const Name &) = default; \ Name(Name &&) = default; \ @@ -140,6 +142,18 @@ namespace detail { using det_traverse_1_cb_rw = decltype(T(nullptr)->traverse_1_cb_rw(nullptr, nullptr)); + template + using det_get = decltype(std::declval().get()); + + template + using det_const_get = decltype(std::declval().get()); + + template + using det_begin = decltype(std::declval().begin()); + + template + using det_end = decltype(std::declval().end()); + inline drjit::string get_label(const char *s, size_t i) { auto skip = [](char c) { return c == ' ' || c == '\r' || c == '\n' || c == '\t' || c == ','; @@ -180,10 +194,17 @@ template auto labels(const T &v) { } template -void traverse_1_fn_ro(const Value &value, void *payload, void (*fn)(void *, uint64_t)) { - (void) payload; (void) fn; +void traverse_1_fn_ro(const Value &value, void *payload, + void (*fn)(void *, uint64_t, const char *, + const char *)) { + DRJIT_MARK_USED(payload); + DRJIT_MARK_USED(fn); if constexpr (is_jit_v && depth_v == 1) { - fn(payload, value.index_combined()); + if constexpr(Value::IsClass) + fn(payload, value.index_combined(), Value::CallSupport::Variant, + Value::CallSupport::Domain); + else + fn(payload, value.index_combined(), "", ""); } else if constexpr (is_traversable_v) { traverse_1(fields(value), [payload, fn](auto &x) { traverse_1_fn_ro(x, payload, fn); @@ -198,14 +219,34 @@ void traverse_1_fn_ro(const Value &value, void *payload, void (*fn)(void *, uint is_detected_v) { if (value) value->traverse_1_cb_ro(payload, fn); + + } else if constexpr (is_detected_v && + is_detected_v) { + for (auto elem : value) { + traverse_1_fn_ro(elem, payload, fn); + } + } else if constexpr (is_detected_v) { + const auto *tmp = value.get(); + traverse_1_fn_ro(tmp, payload, fn); + } else if constexpr (is_detected_v) { + value.traverse_1_cb_ro(payload, fn); } } template -void traverse_1_fn_rw(Value &value, void *payload, uint64_t (*fn)(void *, uint64_t)) { - (void) payload; (void) fn; +void traverse_1_fn_rw(Value &value, void *payload, + uint64_t (*fn)(void *, uint64_t, const char *, + const char *)) { + DRJIT_MARK_USED(payload); + DRJIT_MARK_USED(fn); if constexpr (is_jit_v && depth_v == 1) { - value = Value::borrow((typename Value::Index) fn(payload, value.index_combined())); + if constexpr(Value::IsClass) + value = Value::borrow((typename Value::Index) fn( + payload, value.index_combined(), Value::CallSupport::Variant, + Value::CallSupport::Domain)); + else + value = Value::borrow((typename Value::Index) fn( + payload, value.index_combined(), "", "")); } else if constexpr (is_traversable_v) { traverse_1(fields(value), [payload, fn](auto &x) { traverse_1_fn_rw(x, payload, fn); @@ -220,6 +261,16 @@ void traverse_1_fn_rw(Value &value, void *payload, uint64_t (*fn)(void *, uint64 is_detected_v) { if (value) value->traverse_1_cb_rw(payload, fn); + } else if constexpr (is_detected_v && + is_detected_v) { + for (auto elem : value) { + traverse_1_fn_rw(elem, payload, fn); + } + } else if constexpr (is_detected_v) { + auto *tmp = value.get(); + traverse_1_fn_rw(tmp, payload, fn); + } else if constexpr (is_detected_v) { + value.traverse_1_cb_rw(payload, fn); } } diff --git a/include/drjit/autodiff.h b/include/drjit/autodiff.h index ad22d630..6049ca42 100644 --- a/include/drjit/autodiff.h +++ b/include/drjit/autodiff.h @@ -973,7 +973,9 @@ NAMESPACE_BEGIN(detail) /// Internal operations for traversing nested data structures and fetching or /// storing indices. Used in ``call.h`` and ``loop.h``. 
-template void collect_indices_fn(void *p, uint64_t index) { +template +void collect_indices_fn(void *p, uint64_t index, const char * /*variant*/, + const char * /*domain*/) { vector &indices = *(vector *) p; if constexpr (IncRef) index = ad_var_inc_ref(index); @@ -985,7 +987,8 @@ struct update_indices_payload { size_t &pos; }; -inline uint64_t update_indices_fn(void *p, uint64_t) { +inline uint64_t update_indices_fn(void *p, uint64_t, const char * /*variant*/, + const char * /*domain*/) { update_indices_payload &payload = *(update_indices_payload *) p; return payload.indices[payload.pos++]; } diff --git a/include/drjit/custom.h b/include/drjit/custom.h index 7a38ec3d..143a2d7d 100644 --- a/include/drjit/custom.h +++ b/include/drjit/custom.h @@ -20,6 +20,7 @@ #include #include +#include NAMESPACE_BEGIN(drjit) NAMESPACE_BEGIN(detail) diff --git a/include/drjit/extra.h b/include/drjit/extra.h index d2f24dd0..9df259d2 100644 --- a/include/drjit/extra.h +++ b/include/drjit/extra.h @@ -250,6 +250,11 @@ extern DRJIT_EXTRA_EXPORT bool ad_release_one_output(drjit::detail::CustomOpBase extern DRJIT_EXTRA_EXPORT void ad_copy_implicit_deps(drjit::vector &, bool input); +/// Retrieve a list of ad indices, that are the target of edges, that have been +/// postponed by the current scope +extern DRJIT_EXTRA_EXPORT void ad_scope_postponed(drjit::vector *dst); + + /// Kahan-compensated floating point atomic scatter-addition extern DRJIT_EXTRA_EXPORT void ad_var_scatter_add_kahan(uint64_t *target_1, uint64_t *target_2, uint64_t value, diff --git a/include/drjit/python.h b/include/drjit/python.h index 29229bda..7e5d9c33 100644 --- a/include/drjit/python.h +++ b/include/drjit/python.h @@ -54,6 +54,7 @@ #include #include #include +#include NAMESPACE_BEGIN(drjit) struct ArrayBinding; @@ -1057,25 +1058,73 @@ template void bind_all(ArrayBinding &b) { // Expose already existing object tree traversal callbacks (T::traverse_1_..) in Python. // This functionality is needed to traverse custom/opaque C++ classes and correctly // update their members when they are used in vectorized loops, function calls, etc. 
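Seen from Python, the re-bound callbacks below now invoke the user-supplied callable with ``(index, variant, domain)`` instead of just ``index``. A hypothetical inspection helper (``obj`` stands for any C++ instance whose bindings were created with ``bind_traverse``):

.. code-block:: python

   def dump_index(index: int, variant: str, domain: str) -> None:
       # ``variant`` and ``domain`` are empty strings unless the traversed
       # variable is a class (pointer) array registered with the registry
       print(f"r{index} variant={variant!r} domain={domain!r}")

   obj._traverse_1_cb_ro(dump_index)  # read-only traversal of ``obj``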
-template auto& bind_traverse(nanobind::class_ &cls) { +template auto &bind_traverse(nanobind::class_ &cls) +{ namespace nb = nanobind; - struct Payload { nb::callable c; }; + struct Payload { + nb::callable c; + }; + + static_assert(std::is_base_of_v); cls.def("_traverse_1_cb_ro", [](const T *self, nb::callable c) { Payload payload{ std::move(c) }; - self->traverse_1_cb_ro((void *) &payload, [](void *p, uint64_t index) { - ((Payload *) p)->c(index); - }); + self->traverse_1_cb_ro((void *) &payload, + [](void *p, uint64_t index, const char *variant, const char *domain) { + ((Payload *) p)->c(index, variant, domain); + }); }); cls.def("_traverse_1_cb_rw", [](T *self, nb::callable c) { Payload payload{ std::move(c) }; - self->traverse_1_cb_rw((void *) &payload, [](void *p, uint64_t index) { - return nb::cast(((Payload *) p)->c(index)); + self->traverse_1_cb_rw((void *) &payload, [](void *p, uint64_t index, + const char *variant, + const char *domain) { + return nb::cast( + ((Payload *) p)->c(index, variant, domain)); }); }); return cls; } +inline void traverse_py_cb_ro(const TraversableBase *base, void *payload, + void (*fn)(void *, uint64_t, const char *variant, + const char *domain)) { + namespace nb = nanobind; + nb::handle self = base->self_py(); + if (!self) + return; + + auto detail = nb::module_::import_("drjit.detail"); + nb::callable traverse_py_cb_ro_fn = + nb::borrow(nb::getattr(detail, "traverse_py_cb_ro")); + + traverse_py_cb_ro_fn(self, + nb::cpp_function([&](uint64_t index, const char *variant, + const char *domain) { + fn(payload, index, variant, domain); + })); +} + +inline void traverse_py_cb_rw(TraversableBase *base, void *payload, + uint64_t (*fn)(void *, uint64_t, const char *, + const char *)) { + + namespace nb = nanobind; + nb::handle self = base->self_py(); + if (!self) + return; + + auto detail = nb::module_::import_("drjit.detail"); + nb::callable traverse_py_cb_rw_fn = + nb::borrow(nb::getattr(detail, "traverse_py_cb_rw")); + + traverse_py_cb_rw_fn(self, + nb::cpp_function([&](uint64_t index, const char *variant, + const char *domain) { + return fn(payload, index, variant, domain); + })); +} + NAMESPACE_END(drjit) diff --git a/include/drjit/texture.h b/include/drjit/texture.h index ff6b6f3a..c67dcbdd 100644 --- a/include/drjit/texture.h +++ b/include/drjit/texture.h @@ -18,6 +18,8 @@ #include #include #include +#include +#include #pragma once @@ -42,7 +44,7 @@ enum class CudaTextureFormat : uint32_t { Float16 = 1, /// Half precision storage format }; -template class Texture { +template class Texture : TraversableBase { public: static constexpr bool IsCUDA = is_cuda_v; static constexpr bool IsDiff = is_diff_v; @@ -1591,6 +1593,48 @@ template class Texture { mutable bool m_tensor_dirty = false; /* Flag to indicate whether public-facing unpadded tensor needs to be updated */ + +public: + void + traverse_1_cb_ro(void *payload, + drjit ::detail ::traverse_callback_ro fn) const override { + // Only traverse the texture for frozen functions, since accidentally + // traversing the scene in loops or vcalls can cause issues. 
+ if (!jit_flag(JitFlag::EnableObjectTraversal)) + return; + + DRJIT_MAP(DR_TRAVERSE_MEMBER_RO, m_value, m_unpadded_value, + m_resolution_opaque, m_inv_resolution); + if constexpr (HasCudaTexture) { + uint32_t n_textures = 1 + ((m_channels - 1) / 4); + std::vector indices(n_textures); + jit_cuda_tex_get_indices(m_handle, indices.data()); + for (uint32_t i = 0; i < n_textures; i++) { + fn(payload, indices[i], "", ""); + } + } + } + void traverse_1_cb_rw(void *payload, + drjit ::detail ::traverse_callback_rw fn) override { + // Only traverse the texture for frozen functions, since accidentally + // traversing the scene in loops or vcalls can cause issues. + if (!jit_flag(JitFlag::EnableObjectTraversal)) + return; + + DRJIT_MAP(DR_TRAVERSE_MEMBER_RW, m_value, m_unpadded_value, + m_resolution_opaque, m_inv_resolution); + if constexpr (HasCudaTexture) { + uint32_t n_textures = 1 + ((m_channels - 1) / 4); + std::vector indices(n_textures); + jit_cuda_tex_get_indices(m_handle, indices.data()); + for (uint32_t i = 0; i < n_textures; i++) { + uint64_t new_index = fn(payload, indices[i], "", ""); + if (new_index != indices[i]) + jit_raise("A texture was changed by traversing it. This is " + "not supported!"); + } + } + } }; NAMESPACE_END(drjit) diff --git a/include/drjit/traversable_base.h b/include/drjit/traversable_base.h new file mode 100644 index 00000000..3e956365 --- /dev/null +++ b/include/drjit/traversable_base.h @@ -0,0 +1,254 @@ +#pragma once + +#include "fwd.h" +#include +#include +#include +#include +#include +#include + +NAMESPACE_BEGIN(drjit) + +NAMESPACE_BEGIN(detail) +/** + * \brief The callback used to traverse all jit arrays of a C++ object such as a + * Mitsuba scene. + * + * \param payload: + * To wrap closures, a payload can be provided to the ``traverse_1_cb_ro`` + * function, that is passed to the callback. + * + * \param index: + * A non-owning index of the traversed jit array. + * + * \param variant: + * If a ``JitArray`` has the attribute ``IsClass`` it is referring to a + * drjit class. When such a variable is traversed, the ``variant`` and + * ``domain`` string of its ``CallSupport`` is provided to the callback + * using this argument. Otherwise the string is equal to "". + * + * \param domain: + * The domain of the ``CallSupport`` when traversing a class variable. + */ +using traverse_callback_ro = void (*)(void *payload, uint64_t index, + const char *variant, const char *domain); +/** + * \brief The callback used to traverse and modify all jit arrays of a C++ + * object such as a Mitsuba scene. + * + * \param payload: + * To wrap closures, a payload can be provided to the ``traverse_1_cb_ro`` + * function, that is passed to the callback. + * + * \param index: + * A non-owning index of the traversed jit array. + * + * \param variant: + * If a ``JitArray`` has the attribute ``IsClass`` it is referring to a + * drjit class. When such a variable is traversed, the ``variant`` and + * ``domain`` string of its ``CallSupport`` is provided to the callback + * using this argument. Otherwise the string is equal to "". + * + * \param domain: + * The domain of the ``CallSupport`` when traversing a class variable. + * + * \return + * The new index of the traversed variable. This index is borrowed, and + * should therefore be non-owning. + */ +using traverse_callback_rw = uint64_t (*)(void *payload, uint64_t index, + const char *variant, + const char *domain); + +inline void log_member_open(bool rw, const char *member) { + jit_log(LogLevel::Debug, "%s%s{", rw ? 
"rw " : "ro ", member); +} + +inline void log_member_close() { jit_log(LogLevel::Debug, "}"); } + +NAMESPACE_END(detail) + +/** + * \brief Interface for traversing C++ objects. + * + * This interface should be inherited by any class that can be added to the + * registry. We try to ensure this by wrapping the function ``jit_registry_put`` + * in the function ``drjit::registry_put`` that takes a ``TraversableBase`` for + * the pointer argument. + */ +struct DRJIT_EXTRA_EXPORT TraversableBase : public nanobind::intrusive_base { + /** + * \brief Traverse all jit arrays in this c++ object. For every jit + * variable, the callback should be called, with the provided payload + * pointer. + * + * \param payload: + * A pointer to a payload struct. The callback ``cb`` is called with this + * pointer. + * + * \param cb: + * A function pointer, that is called with the ``payload`` pointer, the + * index of the jit variable, and optionally the domain and variant of a + * ``Class`` variable. + */ + virtual void traverse_1_cb_ro(void *payload, + detail::traverse_callback_ro cb) const = 0; + + /** + * \brief Traverse all jit arrays in this c++ object, and assign the output of the + * callback to them. For every jit variable, the callback should be called, + * with the provided payload pointer. + * + * \param payload: + * A pointer to a payload struct. The callback ``cb`` is called with this + * pointer. + * + * \param cb: + * A function pointer, that is called with the ``payload`` pointer, the + * index of the jit variable, and optionally the domain and variant of a + * ``Class`` variable. The resulting index of calling this function + * pointer will be assigned to the traversed variable. The return value + * of the is borrowed from when overwriting assigning the traversed + * variable. + */ + virtual void traverse_1_cb_rw(void *payload, + detail::traverse_callback_rw cb) = 0; +}; + +/** + * \brief Macro for generating call to \c traverse_1_fn_ro for a class member. + * + * This is only a utility macro, for the DR_TRAVERSE_CB_RO macro. It can only be + * used in a context, where the ``payload`` and ``fn`` variables are present. + */ +#define DR_TRAVERSE_MEMBER_RO(member) \ + drjit::detail::log_member_open(false, #member); \ + drjit::traverse_1_fn_ro(member, payload, fn); \ + drjit::detail::log_member_close(); + +/** + * \brief Macro for generating call to \c traverse_1_fn_rw for a class member. + * + * This is only a utility macro, for the DR_TRAVERSE_CB_RW macro. It can only be + * used in a context, where the ``payload`` and ``fn`` variables are present. + */ +#define DR_TRAVERSE_MEMBER_RW(member) \ + drjit::detail::log_member_open(true, #member); \ + drjit::traverse_1_fn_rw(member, payload, fn); \ + drjit::detail::log_member_close(); + +/** + * \brief Macro, generating the implementation of the ``traverse_1_cb_ro`` + * method. + * + * The first argument should be the base class, from which the current class + * inherits. The other arguments should list all members of that class, which + * are supposed to be read only traversable. + */ +#define DR_TRAVERSE_CB_RO(Base, ...) \ + void traverse_1_cb_ro(void *payload, \ + drjit::detail::traverse_callback_ro fn) \ + const override { \ + static_assert( \ + std::is_base_of>::value); \ + if constexpr (!std::is_same_v) \ + Base::traverse_1_cb_ro(payload, fn); \ + DRJIT_MAP(DR_TRAVERSE_MEMBER_RO, __VA_ARGS__) \ + } + +/** + * \breif Macro, generating the implementation of the ``traverse_1_cb_rw`` + * method. 
+ * + * The first argument should be the base class, from which the current class + * inherits. The other arguments should list all members of that class, which + * are supposed to be read and write traversable. + */ +#define DR_TRAVERSE_CB_RW(Base, ...) \ + void traverse_1_cb_rw(void *payload, \ + drjit::detail::traverse_callback_rw fn) override { \ + static_assert( \ + std::is_base_of>::value); \ + if constexpr (!std::is_same_v) \ + Base::traverse_1_cb_rw(payload, fn); \ + DRJIT_MAP(DR_TRAVERSE_MEMBER_RW, __VA_ARGS__) \ + } + +/** + * \brief Macro, generating the both the implementations of the + * ``traverse_1_cb_ro`` and ``traverse_1_cb_rw`` methods. + * + * The first argument should be the base class, from which the current class + * inherits. The other arguments should list all members of that class, which + * are supposed to be read and write traversable. + */ +#define DR_TRAVERSE_CB(Base, ...) \ +public: \ + DR_TRAVERSE_CB_RO(Base, __VA_ARGS__) \ + DR_TRAVERSE_CB_RW(Base, __VA_ARGS__) + +/** + * \brief Macro, generating the implementations of ``traverse_1_cb_ro`` and + * ``traverse_1_cb_rw`` of a nanobind trampoline class. + * + * This macro should only be instantiated on trampoline classes, that serve as + * the base class for derived types in Python. Adding this macro to a trampoline + * class, allows for the automatic traversal of all python members in any + * derived python class. + */ +#define DR_TRAMPOLINE_TRAVERSE_CB(Base) \ +public: \ + void traverse_1_cb_ro(void *payload, \ + drjit::detail::traverse_callback_ro fn) \ + const override { \ + DRJIT_MARK_USED(payload); \ + DRJIT_MARK_USED(fn); \ + if constexpr (!std ::is_same_v) \ + Base::traverse_1_cb_ro(payload, fn); \ + drjit::traverse_py_cb_ro(this, payload, fn); \ + } \ + void traverse_1_cb_rw(void *payload, \ + drjit::detail::traverse_callback_rw fn) override { \ + DRJIT_MARK_USED(payload); \ + DRJIT_MARK_USED(fn); \ + if constexpr (!std ::is_same_v) \ + Base::traverse_1_cb_rw(payload, fn); \ + drjit::traverse_py_cb_rw(this, payload, fn); \ + } + +/** + * \brief Register a \c TraversableBase pointer with Dr.Jit's pointer registry + * + * This should be used instead of \c jit_registry_put, as it enforces the + * pointers to be of type \c TraversableBase, allowing for traversal of registry + * bound pointers. + * + * Dr.Jit provides a central registry that maps registered pointer values to + * low-valued 32-bit IDs. The main application is efficient virtual function + * dispatch via \ref jit_var_call(), through the registry could be used for + * other applications as well. + * + * This function registers the specified pointer \c ptr with the registry, + * returning the associated ID value, which is guaranteed to be unique within + * the specified domain identified by the \c (variant, domain) strings. + * The domain is normally an identifier that is associated with the "flavor" + * of the pointer (e.g. instances of a particular class), and which ensures + * that the returned ID values are as low as possible. + * + * Caution: for reasons of efficiency, the \c domain parameter is assumed to a + * static constant that will remain alive. The RTTI identifier + * typeid(MyClass).name() is a reasonable choice that satisfies this + * requirement. + * + * Raises an exception when ``ptr`` is ``nullptr``, or when it has already been + * registered with *any* domain. 
+ */ +inline uint32_t registry_put(const char *variant, const char *domain, + TraversableBase *ptr) { + return jit_registry_put(variant, domain, (void *) ptr); +} + +NAMESPACE_END(drjit) diff --git a/src/extra/autodiff.cpp b/src/extra/autodiff.cpp index 11c5fda2..9ae57f09 100644 --- a/src/extra/autodiff.cpp +++ b/src/extra/autodiff.cpp @@ -43,6 +43,7 @@ */ #include "common.h" +#include #include #include #include @@ -1783,6 +1784,18 @@ void ad_scope_leave(bool process_postponed) { } } +void ad_scope_postponed(drjit::vector *dst) { + LocalState &ls = local_state; + std::vector &scopes = ls.scopes; + if (scopes.empty()) + ad_raise("ad_scope_leave(): scope underflow!"); + Scope &scope = scopes.back(); + + for (auto &er : scope.postponed) { + dst->push_back(er.target); + } +} + /// Check if gradient tracking is enabled for the given variable int ad_grad_enabled(Index index) { ADIndex ad_index = ::ad_index(index); diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 9886323d..5d97ee40 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -68,6 +68,7 @@ nanobind_add_module( reduce.h reduce.cpp apply.h apply.cpp eval.h eval.cpp + freeze.h freeze.cpp memop.h memop.cpp slice.h slice.cpp dlpack.h dlpack.cpp diff --git a/src/python/apply.cpp b/src/python/apply.cpp index c611e099..ce05011d 100644 --- a/src/python/apply.cpp +++ b/src/python/apply.cpp @@ -603,11 +603,12 @@ struct recursion_guard { ~recursion_guard() { recursion_level--; } }; -void TraverseCallback::operator()(uint64_t) { } +uint64_t TraverseCallback::operator()(uint64_t, const char *, const char *) { return 0; } void TraverseCallback::traverse_unknown(nb::handle) { } /// Invoke the given callback on leaf elements of the pytree 'h' -void traverse(const char *op, TraverseCallback &tc, nb::handle h) { +void traverse(const char *op, TraverseCallback &tc, nb::handle h, + bool rw) { nb::handle tp = h.type(); recursion_guard guard; @@ -622,30 +623,64 @@ void traverse(const char *op, TraverseCallback &tc, nb::handle h) { len = s.len(inst_ptr(h)); for (Py_ssize_t i = 0; i < len; ++i) - traverse(op, tc, nb::steal(s.item(h.ptr(), i))); + traverse(op, tc, nb::steal(s.item(h.ptr(), i)), rw); } else { tc(h); } } else if (tp.is(&PyTuple_Type)) { for (nb::handle h2 : nb::borrow(h)) - traverse(op, tc, h2); + traverse(op, tc, h2, rw); } else if (tp.is(&PyList_Type)) { for (nb::handle h2 : nb::borrow(h)) - traverse(op, tc, h2); + traverse(op, tc, h2, rw); } else if (tp.is(&PyDict_Type)) { for (nb::handle h2 : nb::borrow(h).values()) - traverse(op, tc, h2); + traverse(op, tc, h2, rw); } else { if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { for (auto [k, v] : ds) - traverse(op, tc, nb::getattr(h, k)); + traverse(op, tc, nb::getattr(h, k), rw); } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { for (nb::handle field : df) { nb::object k = field.attr(DR_STR(name)); - traverse(op, tc, nb::getattr(h, k)); + traverse(op, tc, nb::getattr(h, k), rw); } - } else if (nb::object cb = get_traverse_cb_ro(tp); cb.is_valid()) { - cb(h, nb::cpp_function([&](uint64_t index) { tc(index); })); + } else if (auto traversable = get_traversable_base(h); + traversable) { + struct Payload { + TraverseCallback &tc; + }; + Payload p{ tc }; + if (rw) { + traversable->traverse_1_cb_rw( + (void *) &p, + [](void *p, uint64_t index, const char *variant, + const char *domain) -> uint64_t { + Payload *payload = (Payload *) p; + uint64_t new_index = + payload->tc(index, variant, domain); + return new_index; + }); + } else { 
+ traversable->traverse_1_cb_ro( + (void *) &p, + [](void *p, uint64_t index, const char *variant, + const char *domain) { + Payload *payload = (Payload *) p; + payload->tc(index, variant, domain); + }); + } + } else if (auto cb = get_traverse_cb_ro(tp); cb.is_valid() && !rw) { + cb(h, nb::cpp_function([&](uint64_t index, const char *variant, + const char *domain) { + tc(index, variant, domain); + })); + } else if (nb::object cb = get_traverse_cb_rw(tp); + cb.is_valid() && rw) { + cb(h, nb::cpp_function([&](uint64_t index, const char *variant, + const char *domain) { + return tc(index, variant, domain); + })); } else { tc.traverse_unknown(h); } @@ -903,7 +938,9 @@ nb::object transform(const char *op, TransformCallback &tc, nb::handle h) { } result = tp(**tmp); } else if (nb::object cb = get_traverse_cb_rw(tp); cb.is_valid()) { - cb(h, nb::cpp_function([&](uint64_t index) { return tc(index); })); + cb(h, + nb::cpp_function([&](uint64_t index, const char *, + const char *) { return tc(index); })); result = nb::borrow(h); } else if (!result.is_valid()) { result = tc.transform_unknown(h); diff --git a/src/python/apply.h b/src/python/apply.h index 8e67ee08..c3028964 100644 --- a/src/python/apply.h +++ b/src/python/apply.h @@ -57,7 +57,8 @@ struct TraverseCallback { // Type-erased form which is needed in some cases to traverse into opaque // C++ code. This one just gets called with Jit/AD variable indices, an // associated Python/ instance/type is not available. - virtual void operator()(uint64_t index); + virtual uint64_t operator()(uint64_t index, const char *variant = nullptr, + const char *domain = nullptr); // Traverse an unknown object virtual void traverse_unknown(nb::handle h); @@ -80,9 +81,16 @@ struct TransformCallback { /// Initialize 'h2' (already allocated) based on 'h1' virtual void operator()(nb::handle h1, nb::handle h2) = 0; - // Type-erased form which is needed in some cases to traverse into opaque - // C++ code. This one just gets called with Jit/AD variable indices, an - // associated Python/ instance/type is not available. + /** Type-erased form which is needed in some cases to traverse into opaque + * C++ code. This one just gets called with Jit/AD variable indices, an + * associated Python/ instance/type is not available. + * This can optionally return a non-owning jit_index, that will be assigned + * to the underlying variable if \c traverse is called with the \c rw + * argument set to \c true. This can be used to modify JIT variables of + * PyTrees and their C++ objects in-place. For example, when applying + * operations such as \c jit_var_schedule_force to every JIT variable in a + * PyTree. + */ virtual uint64_t operator()(uint64_t index); }; @@ -96,9 +104,27 @@ struct TransformPairCallback { virtual nb::object transform_unknown(nb::handle h1, nb::handle h2) const; }; -/// Invoke the given callback on leaf elements of the pytree 'h' -extern void traverse(const char *op, TraverseCallback &callback, - nb::handle h); +/** + * \brief Invoke the given callback on leaf elements of the pytree 'h', + * including JIT indices in c++ objects, inheriting from + * \c drjit::TraversableBase. + * + * \param op: + * Name of the operation that is performed, this will be used in the + * exceptions that might be raised during traversal. + * + * \param callback: + * The \c TraverseCallback, called for every Jit variable in the pytree. + * + * \param rw: + * Boolean, indicating if C++ objects should be traversed in read-write + * mode. 
If this is set to \c true, the result from the method + * \c operator()(uint64_t) of the callback will be assigned to the + * underlying variable. This does not change how Python objects are + * traversed. + */ +extern void traverse(const char *op, TraverseCallback &callback, nb::handle h, + bool rw = false); /// Parallel traversal of two compatible pytrees 'h1' and 'h2' extern void traverse_pair(const char *op, TraversePairCallback &callback, diff --git a/src/python/common.h b/src/python/common.h index 215e840e..07fc5c85 100644 --- a/src/python/common.h +++ b/src/python/common.h @@ -88,6 +88,13 @@ inline nb::object get_dataclass_fields(nb::handle tp) { } return result; } +/// Return a pointer to the underlying C++ class if the Python object inherits +/// from TraversableBase or null otherwise +inline drjit::TraversableBase *get_traversable_base(nb::handle h) { + drjit::TraversableBase *result = nullptr; + nb::try_cast(h, result); + return result; +} /// Extract a read-only callback to traverse custom data structures inline nb::object get_traverse_cb_ro(nb::handle tp) { diff --git a/src/python/detail.cpp b/src/python/detail.cpp index ef9f27e9..9e7a90d9 100644 --- a/src/python/detail.cpp +++ b/src/python/detail.cpp @@ -111,13 +111,15 @@ void collect_indices(nb::handle h, dr::vector &indices, bool inc_ref) void operator()(nb::handle h) override { auto index_fn = supp(h.type()).index; if (index_fn) - operator()(index_fn(inst_ptr(h))); + operator()(index_fn(inst_ptr(h)), nullptr, nullptr); } - void operator()(uint64_t index) override { + uint64_t operator()(uint64_t index, const char *, + const char *) override { if (inc_ref) ad_var_inc_ref(index); result.push_back(index); + return 0; } }; @@ -281,6 +283,115 @@ bool leak_warnings() { return nb::leak_warnings() || jit_leak_warnings() || ad_leak_warnings(); } +// Have to wrap this in an unnamed namespace to prevent collisions with the +// other declaration of ``recursion_guard``. +namespace { +static int recursion_level = 0; + +// PyTrees could theoretically include cycles. Catch infinite recursion below +struct recursion_guard { + recursion_guard() { + if (++recursion_level >= 50) { + PyErr_SetString(PyExc_RecursionError, "runaway recursion detected"); + nb::raise_python_error(); + } + } + ~recursion_guard() { recursion_level--; } +}; +} // namespace + +/** + * \brief Traverses all variables of a python object. + * + * This function is used to traverse variables of python objects, inheriting + * from trampoline classes. This allows the user to freeze for example custom + * BSDFs, without having to declare its variables. + */ +void traverse_py_cb_ro_impl(nb::handle self, nb::callable c) { + recursion_guard guard; + + struct PyTraverseCallback : TraverseCallback { + void operator()(nb::handle h) override { + const ArraySupplement &s = supp(h.type()); + auto index_fn = s.index; + if (index_fn){ + if (s.is_class){ + auto variant = + nb::borrow(nb::getattr(h, "Variant")); + auto domain = + nb::borrow(nb::getattr(h, "Domain")); + operator()(index_fn(inst_ptr(h)), + variant.is_valid() ? variant.c_str() : "", + domain.is_valid() ? 
domain.c_str() : ""); + } + else + operator()(index_fn(inst_ptr(h)), "", ""); + } + } + uint64_t operator()(uint64_t index, const char *variant, + const char *domain) override { + m_callback(index, variant, domain); + return 0; + } + nb::callable m_callback; + + PyTraverseCallback(nb::callable c) : m_callback(c) {} + }; + + PyTraverseCallback traverse_cb(std::move(c)); + + auto dict = nb::borrow(nb::getattr(self, "__dict__")); + + for (auto value : dict.values()) { + traverse("traverse_py_cb_ro", traverse_cb, value); + } +} + +/** + * \brief Traverses all variables of a python object. + * + * This function is used to traverse variables of python objects, inheriting + * from trampoline classes. This allows the user to freeze for example custom + * BSDFs, without having to declare its variables. + */ +void traverse_py_cb_rw_impl(nb::handle self, nb::callable c) { + recursion_guard guard; + + struct PyTraverseCallback : TraverseCallback { + void operator()(nb::handle h) override { + const ArraySupplement &s = supp(h.type()); + auto index_fn = s.index; + if (index_fn){ + uint64_t new_index; + if (s.is_class) { + auto variant = + nb::borrow(nb::getattr(h, "Variant")); + auto domain = nb::borrow(nb::getattr(h, "Domain")); + new_index = operator()( + index_fn(inst_ptr(h)), + variant.is_valid() ? variant.c_str() : "", + domain.is_valid() ? domain.c_str() : ""); + } else + new_index = operator()(index_fn(inst_ptr(h)), "", ""); + s.reset_index(new_index, inst_ptr(h)); + } + } + uint64_t operator()(uint64_t index, const char *variant, const char *domain) override { + return nb::cast(m_callback(index, variant, domain)); + } + nb::callable m_callback; + + PyTraverseCallback(nb::callable c) : m_callback(c) {} + }; + + PyTraverseCallback traverse_cb(std::move(c)); + + auto dict = nb::borrow(nb::getattr(self, "__dict__")); + + for (auto value : dict.values()) { + traverse("traverse_py_cb_rw", traverse_cb, value, true); + } +} void export_detail(nb::module_ &) { nb::module_ d = nb::module_::import_("drjit.detail"); @@ -344,6 +455,8 @@ void export_detail(nb::module_ &) { d.def("leak_warnings", &leak_warnings, doc_leak_warnings); d.def("set_leak_warnings", &set_leak_warnings, doc_set_leak_warnings); + d.def("traverse_py_cb_ro", &traverse_py_cb_ro_impl); + d.def("traverse_py_cb_rw", traverse_py_cb_rw_impl); trace_func_handle = d.attr("trace_func"); } diff --git a/src/python/docstr.rst b/src/python/docstr.rst index 0046499e..87b2e978 100644 --- a/src/python/docstr.rst +++ b/src/python/docstr.rst @@ -6040,6 +6040,35 @@ Note that this information can also be queried in a more fine-grained manner (per variable) using the :py:attr:`drjit.ArrayBase.state` field. +.. topic:: JitFlag_KernelFreezing + + Enable recording and replay of functions annotated with :py:func:`freeze`. + + If KernelFreezing is enabled, all Dr.Jit operations executed in a function + annotated with :py:func:`freeze` are recorded during its first execution + and replayed without re-tracing on subsequent calls. + + If this flag is disabled, replay of previously frozen functions is disabled + as well. + +.. topic:: JitFlag_FreezingScope + + This flag is set to ``True`` when Dr.Jit is currently recording a frozen + function. The flag is automatically managed and should not be updated by + application code. + + User code may query this flag to conditionally optimize kernels for frozen + function recording, such as re-seeding the sampler, used for rendering. + +.. 
topic:: JitFlag_EnableObjectTraversal + + This flag is set to ``True`` when Dr.Jit is currently traversing + inputs and outputs of a frozen function. The flag is automatically managed + and should not be updated by application code. + + When enabled, traversal of objects such as the ``Scene`` or ``BSDFs`` in + mitsuba is enabled. + .. topic:: JitFlag_Default The default set of optimization flags consisting of diff --git a/src/python/freeze.cpp b/src/python/freeze.cpp new file mode 100644 index 00000000..421fb6ea --- /dev/null +++ b/src/python/freeze.cpp @@ -0,0 +1,1551 @@ +#include "freeze.h" +#include "apply.h" +#include "autodiff.h" +#include "base.h" +#include "common.h" +#include "reduce.h" +#include "listobject.h" +#include "object.h" +#include "pyerrors.h" +#include "shape.h" +#include "tupleobject.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/** + * \brief Helper struct to profile and log frozen functions. + */ +struct ProfilerPhase { + std::string m_message; + ProfilerPhase(const char *message) : m_message(message) { + jit_log(LogLevel::Debug, "profiler start: %s", message); +#if defined(DRJIT_ENABLE_NVTX) + jit_profile_range_push(message); +#endif + } + + ProfilerPhase(const drjit::TraversableBase *traversable) { + int status; + char message[1024] = {0}; + const char *name = typeid(*traversable).name(); + snprintf(message, 1024, "traverse_cb %s", name); + + jit_log(LogLevel::Debug, "profiler start: %s", message); + jit_profile_range_push(message); + m_message = message; + } + + ~ProfilerPhase() { +#if defined(DRJIT_ENABLE_NVTX) + jit_profile_range_pop(); +#endif + jit_log(LogLevel::Debug, "profiler end: %s", m_message.c_str()); + } +}; + +struct ADScopeContext { + bool process_postponed; + ADScopeContext(drjit::ADScope type, size_t size, const uint64_t *indices, + int symbolic, bool process_postponed) + : process_postponed(process_postponed) { + ad_scope_enter(type, size, indices, symbolic); + } + ~ADScopeContext() { ad_scope_leave(process_postponed); } +}; + +struct scoped_set_flag { + uint32_t backup; + scoped_set_flag(JitFlag flag, bool enabled) : backup(jit_flags()) { + uint32_t flags = backup; + if (enabled) + flags |= (uint32_t)flag; + else + flags &= ~(uint32_t) flag; + + jit_set_flags(flags); + } + + ~scoped_set_flag() { + jit_set_flags(backup); + } +}; + +using namespace detail; + +bool Layout::operator==(const Layout &rhs) const { + if (((bool) this->type != (bool) rhs.type) || !(this->type.equal(rhs.type))) + return false; + + if (this->num != rhs.num) + return false; + + if (this->fields.size() != rhs.fields.size()) + return false; + + for (uint32_t i = 0; i < this->fields.size(); ++i) { + if (!(this->fields[i].equal(rhs.fields[i]))) + return false; + } + + if (this->index != rhs.index) + return false; + + if (this->flags != rhs.flags) + return false; + + if (this->literal != rhs.literal) + return false; + + if (this->vt != rhs.vt) + return false; + + if (((bool) this->py_object != (bool) rhs.py_object) || + !this->py_object.equal(rhs.py_object)) + return false; + + return true; +} + +bool VarLayout::operator==(const VarLayout &rhs) const { + if (this->vt != rhs.vt) + return false; + + if (this->vs != rhs.vs) + return false; + + if (this->flags != rhs.flags) + return false; + + if (this->size_index != rhs.size_index) + return false; + + return true; +} + +/** + * \brief Add a variant domain pair to be traversed using the registry. 
+ * + * When traversing a jit variable, that references a pointer to a class, + * such as a BSDF or Shape in Mitsuba, we have to traverse all objects + * registered with that variant-domain pair in the registry. This function + * adds the variant-domain pair, deduplicating the domain. Whether a + * variable references a class is represented by it's ``IsClass`` const + * attribute. If the domain is an empty string (""), this function skips + * adding the variant-domain pair. + */ +void FlatVariables::add_domain(const char *variant, const char *domain) { + // Since it is not possible to pass nullptr strings to nanobind functions we + // assume, that a valid domain indicates a valid variant. If the variant is + // empty at the end of traversal, we know that no Class variable was + // traversed, and registry traversal is not necessary. + if (domain && variant && strcmp(domain, "") != 0) { + jit_log(LogLevel::Debug, "variant=%s, domain=%s", variant, domain); + + if (domains.empty()){ + this->variant = variant; + } + else if (this->variant != variant) + jit_raise("traverse(): Variant missmatch! All arguments to a " + "frozen function have to have the same variant. " + "Variant %s of a previos argument does not match " + "variant %s of this argument.", + this->variant.c_str(), variant); + + bool contains = false; + for (std::string &d : domains) { + if (d == domain) { + contains = true; + break; + } + } + if (!contains) + domains.push_back(domain); + } +} + +/** + * Adds a jit index to the flattened array, deduplicating it. + * This allows to check for aliasing conditions, where two variables + * actually refer to the same index. The function should only be called for + * scheduled non-literal variable indices. + */ +uint32_t FlatVariables::add_jit_index(uint32_t index) { + uint32_t next_slot = this->variables.size(); + auto result = this->index_to_slot.try_emplace(index, next_slot); + auto it = result.first; + bool inserted = result.second; + + if (inserted) { + this->variables.push_back(index); + // Borrow the variable + jit_var_inc_ref(index); + this->var_layout.emplace_back(); + return next_slot; + } else { + return it.value(); + } +} + +/** + * \brief Records information about jit variables, that have been traversed. + * + * After traversing the PyTree, collecting non-literal indices in + * ``variables`` and evaluating the collected indices, we can collect + * information about the underlying variables. This information is used in + * the key of the ``RecordingMap`` to determine which recording should be + * replayed or if the function has to be re-traced. This function iterates + * over the collected indices and collects that information. + */ +void FlatVariables::record_jit_variables() { + assert(variables.size() == var_layout.size()); + for (uint32_t i = 0; i < var_layout.size(); i++){ + uint32_t index = variables[i]; + VarLayout &layout = var_layout[i]; + + VarInfo info = jit_set_backend(index); + + if (backend == info.backend || this->backend == JitBackend::None) { + backend = info.backend; + } else { + jit_raise("freeze(): backend missmatch error (backend of this " + "variable %s does not match backend of others %s)!", + info.backend == JitBackend::CUDA ? "CUDA" : "LLVM", + backend == JitBackend::CUDA ? "CUDA" : "LLVM"); + } + + if (info.type == VarType::Pointer) { + // We do not support pointers as inputs. It might be possible with + // some extra handling, but they are never used directly. 
+ jit_raise("Pointer inputs not supported!"); + } + + layout.vs = info.state; + layout.vt = info.type; + layout.size_index = this->add_size(info.size); + + if (info.state == VarState::Evaluated) { + // Special case, handling evaluated/opaque variables. + + layout.flags |= + (info.size == 1 ? (uint32_t) LayoutFlag::SingletonArray : 0); + layout.flags |= (info.unaligned ? (uint32_t) LayoutFlag::Unaligned : 0); + + } else { + jit_raise("collect(): found variable %u in unsupported state %u!", + index, (uint32_t) info.state); + } + } +} + +/** + * This function returns an index of an equivalence class for the variable + * size in the flattened variables. + * It uses a hashmap and vector to deduplicate sizes. + * + * This is necessary, to catch cases, where two variables had the same size + * when freezing a function and two different sizes when replaying. + * In that case one kernel would be recorded, that evaluates both variables. + * However, when replaying two kernels would have to be launched since the + * now differently sized variables cannot be evaluated by the same kernel. + */ +uint32_t FlatVariables::add_size(uint32_t size) { + uint32_t next_slot = this->sizes.size(); + auto result = this->size_to_slot.try_emplace(size, next_slot); + auto it = result.first; + bool inserted = result.second; + + if (inserted) { + this->sizes.push_back(size); + return next_slot; + } else { + return it.value(); + } +} + +/** + * Traverse the variable referenced by a jit index and add it to the flat + * variables. An optional type python type can be supplied if it is known. + * Depending on the ``TraverseContext::schedule_force`` the underlying + * variable is either scheduled (``jit_var_schedule``) or force scheduled + * (``jit_var_schedule_force``). If the variable after evaluation is a + * literal, it is directly recorded in the ``layout`` otherwise, it is added + * to the ``variables`` array, allowing the variables to be used when + * recording the frozen function. + */ +void FlatVariables::traverse_jit_index(uint32_t index, TraverseContext &ctx, + nb::handle tp) { + (void) ctx; + Layout &layout = this->layout.emplace_back(); + + if (tp) + layout.type = nb::borrow(tp); + + int rv = 0; + if (ctx.schedule_force) { + // Returns owning reference + index = jit_var_schedule_force(index, &rv); + } else { + // Schedule and create owning reference + rv = jit_var_schedule(index); + jit_var_inc_ref(index); + } + + VarInfo info = jit_set_backend(index); + if (info.state == VarState::Literal) { + // Special case, where the variable is a literal. This should not + // occur, as all literals are made opaque in beforehand, however it + // is nice to have a fallback. + layout.literal = info.literal; + // Store size in index variable, as this is not used for literals + layout.index = info.size; + layout.vt = info.type; + + layout.flags |= (uint32_t) LayoutFlag::Literal; + } else { + layout.index = this->add_jit_index(index); + } + jit_var_dec_ref(index); +} + +/** + * Construct a variable, given it's layout. + * This is the counterpart to `traverse_jit_index`. + * + * Optionally, the index of a variable can be provided that will be + * overwritten with the result of this function. In that case, the function + * will check for compatible variable types. 
+ */ +uint32_t FlatVariables::construct_jit_index(uint32_t prev_index) { + Layout &layout = this->layout[layout_index++]; + + uint32_t index; + VarType vt; + if (layout.flags & (uint32_t) LayoutFlag::Literal) { + index = jit_var_literal(this->backend, layout.vt, &layout.literal, + layout.index); + vt = layout.vt; + } else { + VarLayout &var_layout = this->var_layout[layout.index]; + index = this->variables[layout.index]; + jit_log(LogLevel::Debug, " uses output[%u] = r%u", layout.index, + index); + + jit_var_inc_ref(index); + vt = var_layout.vt; + } + + if (prev_index) { + if (vt != (VarType) jit_var_type(prev_index)) + jit_fail("VarType missmatch %u != %u while assigning (r%u) " + "-> (r%u)!", + (uint32_t) vt, (uint32_t) jit_var_type(prev_index), + (uint32_t) prev_index, (uint32_t) index); + } + return index; +} + +/** + * Add an ad variable by it's index. Both the value and gradient are added + * to the flattened variables. If the ad index has been marked as postponed + * in the \c TraverseContext.postponed field, we mark the resulting layout + * with that flag. This will cause the gradient edges to be propagated when + * assigning to the input. The function takes an optional python-type if + * it is known. + */ +void FlatVariables::traverse_ad_index(uint64_t index, TraverseContext &ctx, + nb::handle tp) { + // NOTE: instead of emplacing a Layout representing the ad variable always, + // we only do so if the gradients have been enabled. We use this format, + // since most variables will not be ad enabled. The layout therefore has to + // be peeked in ``construct_ad_index`` before deciding if an ad or jit + // index should be constructed/assigned. + int grad_enabled = ad_grad_enabled(index); + if (grad_enabled) { + Layout &layout = this->layout.emplace_back(); + uint32_t ad_index = (uint32_t) (index >> 32); + + if (tp) + layout.type = nb::borrow(tp); + layout.num = 2; + // layout.vt = jit_var_type(index); + + // Set flags + layout.flags |= (uint32_t) LayoutFlag::GradEnabled; + // If the edge with this node as it's target has been postponed by + // the isolate gradient scope, it has been enqueued and we mark the + // ad variable as such. + if (ctx.postponed && ctx.postponed->contains(ad_index)) { + layout.flags |= (uint32_t) LayoutFlag::Postponed; + } + + traverse_jit_index((uint32_t) index, ctx, tp); + uint32_t grad = ad_grad(index); + traverse_jit_index(grad, ctx, tp); + jit_var_dec_ref(grad); + } else { + traverse_jit_index(index, ctx, tp); + } +} + +/** + * Construct/assign the variable index given a layout. + * This corresponds to `traverse_ad_index`. + * + * This function is also used for assignment to ad-variables. + * If a `prev_index` is provided, and it is an ad-variable the gradient and + * value of the flat variables will be applied to the ad variable, + * preserving the ad_idnex. + * + * It returns an owning reference. 
+ */ +uint64_t FlatVariables::construct_ad_index(uint32_t shrink, + uint64_t prev_index) { + Layout &layout = this->layout[this->layout_index]; + + uint64_t index; + if ((layout.flags & (uint32_t) LayoutFlag::GradEnabled) != 0) { + Layout &layout = this->layout[this->layout_index++]; + bool postponed = (layout.flags & (uint32_t) LayoutFlag::Postponed); + + uint32_t val = construct_jit_index(prev_index); + uint32_t grad = construct_jit_index(prev_index); + + // Resize the gradient if it is a literal + if ((VarState) jit_var_state(grad) == VarState::Literal) { + uint32_t new_grad = jit_var_resize(grad, jit_var_size(val)); + jit_var_dec_ref(grad); + grad = new_grad; + } + + // If the prev_index variable is provided we assign the new value + // and gradient to the ad variable of that index instead of creating + // a new one. + uint32_t ad_index = (uint32_t) (prev_index >> 32); + if (ad_index) { + index = (((uint64_t) ad_index) << 32) | ((uint64_t) val); + ad_var_inc_ref(index); + } else + index = ad_var_new(val); + + jit_log(LogLevel::Debug, " -> ad_var r%zu", index); + jit_var_dec_ref(val); + + // Equivalent to set_grad + ad_clear_grad(index); + ad_accum_grad(index, grad); + jit_var_dec_ref(grad); + + // Variables, that have been postponed by the isolate gradient scope + // will be enqueued, which propagates their gradeint to previous + // functions. + if (ad_index && postponed) { + ad_enqueue(drjit::ADMode::Backward, index); + } + } else { + index = construct_jit_index(prev_index); + } + + if (shrink > 0) + index = ad_var_shrink(index, shrink); + + return index; +} + +/** + * Wrapper aground traverse_ad_index for a python handle. + */ +void FlatVariables::traverse_ad_var(nb::handle h, TraverseContext &ctx) { + auto s = supp(h.type()); + + if (s.is_class) { + auto variant = nb::borrow(nb::getattr(h, "Variant")); + auto domain = nb::borrow(nb::getattr(h, "Domain")); + add_domain(variant.c_str(), domain.c_str()); + } + + raise_if(s.index == nullptr, "freeze(): ArraySupplement index function " + "pointer is nullptr."); + + uint64_t index = s.index(inst_ptr(h)); + + this->traverse_ad_index(index, ctx, h.type()); +} + +/** + * Construct an ad variable given it's layout. + * This corresponds to `traverse_ad_var` + */ +nb::object FlatVariables::construct_ad_var(const Layout &layout, + uint32_t shrink) { + uint64_t index = construct_ad_index(shrink); + + auto result = nb::inst_alloc_zero(layout.type); + const ArraySupplement &s = supp(result.type()); + s.init_index(index, inst_ptr(result)); + nb::inst_mark_ready(result); + + // We have to release the reference, since assignment will borrow from + // it. + ad_var_dec_ref(index); + + return result; +} + +/** + * Assigns an ad variable. + * Corresponds to `traverse_ad_var`. + * This uses `construct_ad_index` to either construct a new ad variable or + * assign the value and gradient to an already existing one. + */ +void FlatVariables::assign_ad_var(Layout &layout, nb::handle dst) { + const ArraySupplement &s = supp(layout.type); + + uint64_t index; + if (s.index) { + // ``construct_ad_index`` is used for assignment + index = construct_ad_index(0, s.index(inst_ptr(dst))); + } else + index = construct_ad_index(); + + s.reset_index(index, inst_ptr(dst)); + jit_log(LogLevel::Debug, "index=%zu, grad_enabled=%u, ad_grad_enabled=%u", + index, grad_enabled(dst), ad_grad_enabled(index)); + + // Release reference, since ``construct_ad_index`` returns owning + // reference and ``s.reset_index`` borrows from it. 
+ ad_var_dec_ref(index); +} + +/** + * Traverse a c++ tree using it's `traverse_1_cb_ro` callback. + */ +void FlatVariables::traverse_cb(const drjit::TraversableBase *traversable, + TraverseContext &ctx, nb::object type) { + // ProfilerPhase profiler(traversable); + + uint32_t layout_index = this->layout.size(); + Layout &layout = this->layout.emplace_back(); + layout.type = nb::borrow(type); + + struct Payload{ + TraverseContext &ctx; + FlatVariables *flat_variables = nullptr; + uint32_t num_fields = 0; + }; + + Payload p{ctx, this, 0}; + + traversable->traverse_1_cb_ro((void*) & p, + [](void *p, uint64_t index, const char *variant, const char *domain) { + if (!index) + return; + Payload *payload = (Payload *)p; + payload->flat_variables->add_domain(variant, domain); + payload->flat_variables->traverse_ad_index(index, payload->ctx); + payload->num_fields++; + }); + + this->layout[layout_index].num = p.num_fields; +} + +/** + * Helper function, used to assign a callback variable. + * + * \param tmp + * This vector is populated with the indices to variables that have been + * constructed. It is required to release the references, since the + * references created by `construct_ad_index` are owning and they are + * borrowed after the callback returns. + */ +uint64_t FlatVariables::assign_cb_internal(uint64_t index, + index64_vector &tmp) { + if (!index) + return index; + + uint64_t new_index = this->construct_ad_index(0, index); + + tmp.push_back_steal(new_index); + return new_index; +} + +/** + * Assigns variables using it's `traverse_cb_rw` callback. + * This corresponds to `traverse_cb`. + */ +void FlatVariables::assign_cb(drjit::TraversableBase *traversable) { + Layout &layout = this->layout[layout_index++]; + + + struct Payload{ + FlatVariables *flat_variables = nullptr; + Layout &layout; + index64_vector tmp; + uint32_t field_counter = 0; + }; + Payload p{ this, layout, index64_vector(), 0 }; + traversable->traverse_1_cb_rw( + (void *) &p, [](void *p, uint64_t index, const char *, const char *) { + if (!index) + return index; + Payload *payload = (Payload *) p; + if (payload->field_counter >= payload->layout.num) + jit_raise("While traversing an object " + "for assigning inputs, the number of variables to " + "assign (>%u) did not match the number of variables " + "traversed when recording (%u)!", + payload->field_counter, payload->layout.num); + payload->field_counter++; + return payload->flat_variables->assign_cb_internal(index, payload->tmp); + }); + + if (p.field_counter != layout.num) + jit_raise("While traversing and object for assigning inputs, the " + "number of variables to assign did not match the number " + "of variables traversed when recording!"); +} + +/** + * Traverses a PyTree in DFS order, and records it's layout in the + * `layout` vector. + * + * When hitting a drjit primitive type, it calls the + * `traverse_dr_var` method, which will add their indices to the + * `flat_variables` vector. The collect method will also record metadata + * about the drjit variable in the layout. Therefore, the layout can be used + * as an identifier to the recording of the frozen function. 
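+ *
+ * As an illustrative sketch (field values abbreviated, not an exact dump), an
+ * input such as ``{"a": Float(1, 2), "b": (Float(3, 4), 7)}`` is recorded as
+ * the DFS sequence
+ *
+ *   Layout{type=dict,  num=2, fields=["a", "b"]}
+ *   Layout{type=Float, index=0}       // slot 0 in ``variables``
+ *   Layout{type=tuple, num=2}
+ *   Layout{type=Float, index=1}       // slot 1 in ``variables``
+ *   Layout{type=int,   py_object=7}   // cached non-Dr.Jit value
+ *
+ * together with the de-duplicated JIT indices collected in ``variables``.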
+ */ +void FlatVariables::traverse(nb::handle h, TraverseContext &ctx) { + recursion_guard guard(this); + + scoped_set_flag traverse_scope(JitFlag::EnableObjectTraversal, true); + + ProfilerPhase profiler("traverse"); + nb::handle tp = h.type(); + + auto tp_name = nb::type_name(tp).c_str(); + jit_log(LogLevel::Debug, "FlatVariables::traverse(): %s {", tp_name); + + try { + uint32_t layout_index = this->layout.size(); + Layout &layout = this->layout.emplace_back(); + layout.type = nb::borrow(tp); + if (is_drjit_type(tp)) { + const ArraySupplement &s = supp(tp); + if (s.is_tensor) { + nb::handle array = s.tensor_array(h.ptr()); + + auto full_shape = nb::borrow(shape(h)); + + // Instead of adding the whole shape of a tensor to the key, we + // only add the inner part, not containing dimension 0. When + // indexing into a tensor, this is the only dimension that is + // not used in the index calculation. When constructing a tensor + // this dimension is reconstructed from the width of the + // underlying array. + + nb::list inner_shape; + if (full_shape.size() > 0) + for (uint32_t i = 1; i < full_shape.size(); i++) { + inner_shape.append(full_shape[i]); + } + + layout.py_object = nb::tuple(inner_shape); + + traverse(nb::steal(array), ctx); + } else if (s.ndim != 1) { + Py_ssize_t len = s.shape[0]; + if (len == DRJIT_DYNAMIC) + len = s.len(inst_ptr(h)); + + layout.num = len; + + for (Py_ssize_t i = 0; i < len; ++i) + traverse(nb::steal(s.item(h.ptr(), i)), ctx); + } else { + traverse_ad_var(h, ctx); + } + } else if (tp.is(&PyTuple_Type)) { + nb::tuple tuple = nb::borrow(h); + + layout.num = tuple.size(); + + for (nb::handle h2 : tuple) { + traverse(h2, ctx); + } + } else if (tp.is(&PyList_Type)) { + nb::list list = nb::borrow(h); + + layout.num = list.size(); + + for (nb::handle h2 : list) { + traverse(h2, ctx); + } + } else if (tp.is(&PyDict_Type)) { + nb::dict dict = nb::borrow(h); + + layout.num = dict.size(); + layout.fields.reserve(layout.num); + for (auto k : dict.keys()) { + layout.fields.push_back(nb::borrow(k)); + } + + for (auto [k, v] : dict) { + traverse(v, ctx); + } + } else if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { + + layout.num = ds.size(); + layout.fields.reserve(layout.num); + for (auto k : ds.keys()) { + layout.fields.push_back(nb::borrow(k)); + } + + for (auto [k, v] : ds) { + traverse(nb::getattr(h, k), ctx); + } + } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { + + for (auto field : df) { + nb::object k = field.attr(DR_STR(name)); + layout.fields.push_back(nb::borrow(k)); + } + layout.num = layout.fields.size(); + + for (nb::handle field : df) { + nb::object k = field.attr(DR_STR(name)); + traverse(nb::getattr(h, k), ctx); + } + } else if (auto traversable = get_traversable_base(h); traversable) { + traverse_cb(traversable, ctx, nb::borrow(tp)); + } else if (auto cb = get_traverse_cb_ro(tp); cb.is_valid()) { + ProfilerPhase profiler("traverse cb"); + + uint32_t num_fields = 0; + + // Traverse the opaque C++ object + cb(h, nb::cpp_function([&](uint64_t index, const char *variant, + const char *domain) { + if (!index) + return; + add_domain(variant, domain); + num_fields++; + this->traverse_ad_index(index, ctx, nb::none()); + return; + })); + + // Update layout number of fields + this->layout[layout_index].num = num_fields; + } else { + jit_log(LogLevel::Info, + "traverse(): You passed a value of type %s to a frozen " + "function, it could not be converted to a Dr.Jit type. 
" + "Changing this value in future calls to the frozen " + "function will cause it to be re-traced.", + nb::str(tp).c_str()); + + layout.py_object = nb::borrow(h); + } + } catch (nb::python_error &e) { + nb::raise_from(e, PyExc_RuntimeError, + "FlatVariables::traverse(): error encountered while " + "processing an argument of type '%U' (see above).", + nb::type_name(tp).ptr()); + } catch (const std::exception &e) { + nb::chain_error(PyExc_RuntimeError, + "FlatVariables::traverse(): error encountered " + "while processing an argument of type '%U': %s", + nb::type_name(tp).ptr(), e.what()); + nb::raise_python_error(); + } + + jit_log(LogLevel::Debug, "}"); +} + +/** + * This is the counterpart to the ``traverse`` method, used to construct the + * output of a frozen function. Given a layout vector and flat_variables, it + * re-constructs the PyTree. + */ +nb::object FlatVariables::construct() { + recursion_guard guard(this); + + if (this->layout.size() == 0) { + return nb::none(); + } + + const Layout &layout = this->layout[layout_index++]; + + auto tp_name = nb::type_name(layout.type).c_str(); + jit_log(LogLevel::Debug, "FlatVariables::construct(): %s {", tp_name); + + if (layout.type.is(nb::none().type())) { + return nb::none(); + } + try { + if (is_drjit_type(layout.type)) { + const ArraySupplement &s = supp(layout.type); + if (s.is_tensor) { + nb::object array = construct(); + + // Reconstruct the full shape from the inner part, stored in the + // layout and the width of the underlying array. + auto inner_shape = nb::borrow(layout.py_object); + auto first_dim = prod(shape(array), nb::none()) + .floor_div(prod(inner_shape, nb::none())); + + nb::list full_shape; + full_shape.append(first_dim); + for (uint32_t i = 0; i < inner_shape.size(); i++) { + full_shape.append(inner_shape[i]); + } + + nb::object tensor = layout.type(array, nb::tuple(full_shape)); + return tensor; + } else if (s.ndim != 1) { + auto result = nb::inst_alloc_zero(layout.type); + dr::ArrayBase *p = inst_ptr(result); + size_t size = s.shape[0]; + if (size == DRJIT_DYNAMIC) { + size = s.len(p); + s.init(size, p); + } + for (size_t i = 0; i < size; ++i) { + result[i] = construct(); + } + nb::inst_mark_ready(result); + return result; + } else { + return construct_ad_var(layout); + } + } else if (layout.type.is(&PyTuple_Type)) { + nb::list list; + for (uint32_t i = 0; i < layout.num; ++i) { + list.append(construct()); + } + return nb::tuple(list); + } else if (layout.type.is(&PyList_Type)) { + nb::list list; + for (uint32_t i = 0; i < layout.num; ++i) { + list.append(construct()); + } + return std::move(list); + } else if (layout.type.is(&PyDict_Type)) { + nb::dict dict; + for (auto k : layout.fields) { + dict[k] = construct(); + } + return std::move(dict); + } else if (nb::dict ds = get_drjit_struct(layout.type); ds.is_valid()) { + nb::object tmp = layout.type(); + // TODO: validation against `ds` + for (auto k : layout.fields) { + nb::setattr(tmp, k, construct()); + } + return tmp; + } else if (nb::object df = get_dataclass_fields(layout.type); + df.is_valid()) { + nb::dict dict; + for (auto k : layout.fields) { + dict[k] = construct(); + } + return layout.type(**dict); + } else if (layout.py_object) { + return layout.py_object; + } else { + nb::raise("Tried to construct a variable of type %s that is not " + "constructable!", + nb::type_name(layout.type).c_str()); + } + } catch (nb::python_error &e) { + nb::raise_from(e, PyExc_RuntimeError, + "FlatVariables::construct(): error encountered while " + "processing an argument of 
type '%U' (see above).", + nb::type_name(layout.type).ptr()); + } catch (const std::exception &e) { + nb::chain_error(PyExc_RuntimeError, + "FlatVariables::construct(): error encountered " + "while processing an argument of type '%U': %s", + nb::type_name(layout.type).ptr(), e.what()); + nb::raise_python_error(); + } + + jit_log(LogLevel::Debug, "}"); +} + +/** + * Assigns the flattened variables to an already existing PyTree. + * This is used when input variables have changed. + */ +void FlatVariables::assign(nb::handle dst) { + recursion_guard guard(this); + scoped_set_flag traverse_scope(JitFlag::EnableObjectTraversal, true); + + nb::handle tp = dst.type(); + Layout &layout = this->layout[layout_index++]; + + jit_log(LogLevel::Debug, "FlatVariables::assign(): %s with %s {", + nb::type_name(tp).c_str(), nb::type_name(layout.type).c_str()); + + if (!layout.type.equal(tp)) + jit_fail( + "Type missmatch! Type of the object when recording (%s) does not " + "match type of object that is assigned (%s).", + nb::type_name(tp).c_str(), nb::type_name(layout.type).c_str()); + + try { + if (is_drjit_type(tp)) { + const ArraySupplement &s = supp(tp); + + if (s.is_tensor) { + nb::handle array = s.tensor_array(dst.ptr()); + assign(nb::steal(array)); + } else if (s.ndim != 1) { + Py_ssize_t len = s.shape[0]; + if (len == DRJIT_DYNAMIC) + len = s.len(inst_ptr(dst)); + + for (Py_ssize_t i = 0; i < len; ++i) + assign(dst[i]); + } else { + assign_ad_var(layout, dst); + } + } else if (tp.is(&PyTuple_Type)) { + nb::tuple tuple = nb::borrow(dst); + raise_if( + tuple.size() != layout.num, + "The number of objects in this tuple changed from %u to %u " + "while recording the function.", + layout.num, (uint32_t) tuple.size()); + + for (nb::handle h2 : tuple) + assign(h2); + } else if (tp.is(&PyList_Type)) { + nb::list list = nb::borrow(dst); + raise_if(list.size() != layout.num, + "The number of objects in this list changed from %u to %u " + "while recording the function.", + layout.num, (uint32_t) list.size()); + + for (nb::handle h2 : list) + assign(h2); + } else if (tp.is(&PyDict_Type)) { + nb::dict dict = nb::borrow(dst); + for (auto &k : layout.fields) { + if (dict.contains(&k)) + assign(dict[k]); + else + dst[k] = construct(); + } + } else if (nb::dict ds = get_drjit_struct(dst); ds.is_valid()) { + for (auto &k : layout.fields) { + if (nb::hasattr(dst, k)) + assign(nb::getattr(dst, k)); + else + nb::setattr(dst, k, construct()); + } + } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { + for (auto k : layout.fields) { + if (nb::hasattr(dst, k)) + assign(nb::getattr(dst, k)); + else + nb::setattr(dst, k, construct()); + } + } else if (auto traversable = get_traversable_base(dst); traversable) { + assign_cb(traversable); + } else if (nb::object cb = get_traverse_cb_rw(tp); cb.is_valid()) { + index64_vector tmp; + uint32_t num_fields = 0; + + cb(dst, nb::cpp_function([&](uint64_t index, const char *, + const char *) { + if (!index) + return index; + jit_log(LogLevel::Debug, + "assign(): traverse_cb[%u] was a%u r%u", num_fields, + (uint32_t) (index >> 32), (uint32_t) index); + num_fields++; + if (num_fields > layout.num) + jit_raise( + "While traversing the object of type %s " + "for assigning inputs, the number of variables " + "to assign (>%u) did not match the number of " + "variables traversed when recording(%u)!", + nb::str(tp).c_str(), num_fields, layout.num); + return assign_cb_internal(index, tmp); + })); + if (num_fields != layout.num) + jit_raise("While traversing the object of type 
%s " + "for assigning inputs, the number of variables " + "to assign did not match the number of variables " + "traversed when recording!", + nb::str(tp).c_str()); + } else { + } + } catch (nb::python_error &e) { + nb::raise_from(e, PyExc_RuntimeError, + "FlatVariables::assign(): error encountered while " + "processing an argument " + "of type '%U' (see above).", + nb::type_name(tp).ptr()); + } catch (const std::exception &e) { + nb::chain_error(PyExc_RuntimeError, + "FlatVariables::assign(): error encountered " + "while processing an argument " + "of type '%U': %s", + nb::type_name(tp).ptr(), e.what()); + nb::raise_python_error(); + } + + jit_log(LogLevel::Debug, "}"); +} + +/** + * First traverses the PyTree, then the registry. This ensures that + * additional data to vcalls is tracked correctly. + */ +void FlatVariables::traverse_with_registry(nb::handle h, TraverseContext &ctx) { + scoped_set_flag traverse_scope(JitFlag::EnableObjectTraversal, true); + + // Traverse the handle + traverse(h, ctx); + + // Traverse the registry (if a class variable was traversed) + if (!domains.empty()) { + ProfilerPhase profiler("traverse_registry"); + uint32_t layout_index = this->layout.size(); + Layout &layout = this->layout.emplace_back(); + layout.type = nb::borrow(nb::none()); + + uint32_t num_fields = 0; + + jit_log(LogLevel::Debug, "registry{"); + + std::vector registry_pointers; + for (std::string &domain : domains) { + uint32_t registry_bound = + jit_registry_id_bound(variant.c_str(), domain.c_str()); + uint32_t offset = registry_pointers.size(); + registry_pointers.resize(registry_pointers.size() + registry_bound, nullptr); + jit_registry_get_pointers(variant.c_str(), domain.c_str(), + ®istry_pointers[offset]); + } + + jit_log(LogLevel::Debug, "registry_bound=%u", registry_pointers.size()); + jit_log(LogLevel::Debug, "layout_index=%u", this->layout.size()); + for (void *ptr : registry_pointers) { + jit_log(LogLevel::Debug, "ptr=%p", ptr); + if (!ptr) + continue; + + // WARN: very unsafe cast! + // We assume, that any object added to the registry inherits from + // TraversableBase. This is ensured by the signature of the + // ``drjit::registry_put`` function. + auto traversable = (drjit::TraversableBase *) ptr; + auto self = traversable->self_py(); + + if (self) + traverse(self, ctx); + else + traverse_cb(traversable, ctx); + + num_fields++; + } + jit_log(LogLevel::Debug, "}"); + + this->layout[layout_index].num = num_fields; + } +} + +/** + * First assigns the registry and then the PyTree. + * Corresponds to `traverse_with_registry`. 
+ */ +void FlatVariables::assign_with_registry(nb::handle dst) { + scoped_set_flag traverse_scope(JitFlag::EnableObjectTraversal, true); + + // Assign the handle + assign(dst); + + // Assign registry (if a class variable was traversed) + if (!domains.empty()) { + Layout &layout = this->layout[layout_index++]; + uint32_t num_fields = 0; + + jit_log(LogLevel::Debug, "registry{"); + + std::vector registry_pointers; + for (std::string &domain : domains) { + uint32_t registry_bound = + jit_registry_id_bound(variant.c_str(), domain.c_str()); + uint32_t offset = registry_pointers.size(); + registry_pointers.resize(registry_pointers.size() + registry_bound, nullptr); + jit_registry_get_pointers(variant.c_str(), domain.c_str(), + ®istry_pointers[offset]); + } + + jit_log(LogLevel::Debug, "registry_bound=%u", registry_pointers.size()); + jit_log(LogLevel::Debug, "layout_index=%u", this->layout_index); + for (void *ptr : registry_pointers) { + jit_log(LogLevel::Debug, "ptr=%p", ptr); + if (!ptr) + continue; + + // WARN: very unsafe cast! + // We assume, that any object added to the registry inherits from + // TraversableBase. This is ensured by the signature of the + // ``drjit::registry_put`` function. + auto traversable = (drjit::TraversableBase *) ptr; + auto self = traversable->self_py(); + + if (self) + assign(self); + else + assign_cb(traversable); + + num_fields++; + } + jit_log(LogLevel::Debug, "}"); + } +} + +inline void hash_combine(size_t &seed, size_t value) { + /// From CityHash (https://github.com/google/cityhash) + const size_t mult = 0x9ddfea08eb382d69ull; + size_t a = (value ^ seed) * mult; + a ^= (a >> 47); + size_t b = (seed ^ a) * mult; + b ^= (b >> 47); + seed = b * mult; +} + +size_t +FlatVariablesHasher::operator()(const std::shared_ptr &key) const { + ProfilerPhase profiler("hash"); + // Hash the layout + // NOTE: string hashing seems to be less efficient + size_t hash = key->layout.size(); + for (const Layout &layout : key->layout) { + hash_combine(hash, layout.num); + hash_combine(hash, layout.fields.size()); + hash_combine(hash, (size_t) layout.flags); + hash_combine(hash, (size_t) layout.index); + hash_combine(hash, (size_t) layout.literal); + hash_combine(hash, (size_t) layout.vt); + if (layout.type) + hash_combine(hash, nb::hash(layout.type)); + if (layout.py_object) + hash_combine(hash, nb::hash(layout.py_object)); + for (auto &field : layout.fields) { + hash_combine(hash, nb::hash(field)); + } + } + + for (const VarLayout &layout : key->var_layout) { + hash_combine(hash, (size_t) layout.vt); + hash_combine(hash, (size_t) layout.vs); + hash_combine(hash, (size_t) layout.flags); + hash_combine(hash, (size_t) layout.size_index); + } + + hash_combine(hash, (size_t) key->flags); + + return hash; +} + +/* + * Record a function, given it's python input and flattened input. + */ +nb::object FunctionRecording::record(nb::callable func, + FrozenFunction *frozen_func, + nb::list input, + const FlatVariables &in_variables) { + ProfilerPhase profiler("record"); + JitBackend backend = in_variables.backend; + + frozen_func->recording_counter++; + if (frozen_func->recording_counter > frozen_func->warn_recording_count)jit_log( + LogLevel::Warn, + "The frozen function has been recorded %u times, this indicates a " + "problem with how the frozen function is being called. 
For " + "example, calling it with changing python values such as a " + "index.", + frozen_func->recording_counter); + + + jit_log(LogLevel::Info, + "Recording (n_inputs=%u):", in_variables.variables.size()); + jit_freeze_start(backend, in_variables.variables.data(), + in_variables.variables.size()); + + // Record the function + nb::object output; + { + ProfilerPhase profiler("function"); + output = func(*input[0], **input[1]); + } + + // Collect nodes, that have been postponed by the `Isolate` scope in a + // hash set. + // These are the targets of postponed edges, as the isolate gradient + // scope only handles backward mode differentiation. + // If they are, then we have to enqueue them when replaying the + // recording. + tsl::robin_set postponed; + { + drjit::vector postponed_vec; + ad_scope_postponed(&postponed_vec); + for (uint32_t index : postponed_vec) + postponed.insert(index); + } + + jit_log(LogLevel::Info, "Traversing output"); + { + ProfilerPhase profiler("traverse output"); + // Enter Resume scope, so we can track gradients + ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, -1, false); + + TraverseContext ctx; + ctx.postponed = &postponed; + ctx.schedule_force = false; + out_variables.traverse(output, ctx); + ctx.schedule_force = true; + out_variables.traverse_with_registry(input, ctx); + + { // Evaluate the variables, scheduled when traversing + nb::gil_scoped_release guard; + jit_eval(); + } + + out_variables.record_jit_variables(); + } + + jit_freeze_pause(backend); + + if ((out_variables.variables.size() > 0 && + in_variables.variables.size() > 0) && + out_variables.backend != backend) { + Recording *recording = jit_freeze_stop(backend, nullptr, 0); + jit_freeze_destroy(recording); + + nb::raise( + "freeze(): backend missmatch error (backend %u of " + "output variables did not match backend %u of input variables)", + (uint32_t) out_variables.backend, (uint32_t) backend); + } + + // Exceptions, thrown by the recording functions will be recorded and + // re-thrown when calling ``jit_freeze_stop``. Since the output variables + // are borrowed, we have to release them in that case, and catch these + // exceptions. + try { + recording = jit_freeze_stop(backend, out_variables.variables.data(), + out_variables.variables.size()); + } catch (nb::python_error &e) { + out_variables.release(); + nb::raise_from(e, PyExc_RuntimeError, + "record(): error encountered while recording a function " + "(see above)."); + } catch (const std::exception &e) { + out_variables.release(); + nb::chain_error(PyExc_RuntimeError, "record(): %s", e.what()); + nb::raise_python_error(); + } + + jit_log(LogLevel::Info, "Recording done (n_outputs=%u)", + out_variables.variables.size()); + + // For catching input assignment mismatches, we assign the input and + // output + { + // Enter Resume scope, so we can track gradients + ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, -1, false); + + out_variables.layout_index = 0; + jit_log(LogLevel::Debug, "Construct:"); + output = nb::borrow(out_variables.construct()); + // NOTE: temporarily disable this to not enqueue twice + out_variables.assign(input); + out_variables.layout_index = 0; + } + + // Traversal takes owning references, so here we need to release them. + out_variables.release(); + + return output; +} +/* + * Replays the recording. + * + * This constructs the output and re-assigns the input. 
+ */ +nb::object FunctionRecording::replay(nb::callable func, + FrozenFunction *frozen_func, + nb::list input, + const FlatVariables &in_variables) { + ProfilerPhase profiler("replay"); + + jit_log(LogLevel::Info, "Replaying:"); + int dryrun_success; + { + ProfilerPhase profiler("dry run"); + dryrun_success = + jit_freeze_dry_run(recording, in_variables.variables.data()); + } + if (!dryrun_success) { + // Dry run has failed. Re-record the function. + jit_log(LogLevel::Info, "Dry run failed! re-recording"); + this->clear(); + try { + return this->record(func, frozen_func, input, in_variables); + } catch (nb::python_error &e) { + nb::raise_from(e, PyExc_RuntimeError, + "replay(): error encountered while re-recording a " + "function (see above)."); + } catch (const std::exception &e) { + jit_freeze_abort(in_variables.backend); + + nb::chain_error(PyExc_RuntimeError, "record(): %s", e.what()); + nb::raise_python_error(); + } + } else { + ProfilerPhase profiler("jit replay"); + nb::gil_scoped_release guard; + jit_freeze_replay(recording, in_variables.variables.data(), + out_variables.variables.data()); + } + jit_log(LogLevel::Info, "Replaying done:"); + + // Construct Output variables + nb::object output; + { + // Enter Resume scope, so we can track gradients + ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, -1, false); + out_variables.layout_index = 0; + { + ProfilerPhase profiler("construct output"); + output = nb::borrow(out_variables.construct()); + } + { + ProfilerPhase profiler("assign input"); + out_variables.assign_with_registry(input); + } + } + + // out_variables is assigned by ``jit_record_replay``, which transfers + // ownership to this array. Therefore, we have to drop the variables + // afterwards. + out_variables.release(); + + return output; +} + +nb::object FrozenFunction::operator()(nb::args args, nb::kwargs kwargs) { + nb::object result; + { + // Enter Isolate grad scope, so that gradients are not propagated + // outside of the function scope. + ADScopeContext ad_scope(drjit::ADScope::Isolate, 0, nullptr, -1, true); + + // Kernel freezing can be enabled or disabled with the + // ``JitFlag::KernelFreezing``. Alternatively, when calling a frozen + // function from another one, we simply record the inner function. 
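+ // ``max_cache_size == 0`` likewise disables freezing, in which case the
+ // wrapped function is simply called directly.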
+ if (!jit_flag(JitFlag::KernelFreezing) || + jit_flag(JitFlag::FreezingScope) || max_cache_size == 0) { + ProfilerPhase profiler("function"); + return func(*args, **kwargs); + } + + call_counter++; + + nb::list input; + input.append(args); + input.append(kwargs); + + auto in_variables = + std::make_shared(FlatVariables(in_heuristics)); + in_variables->backend = this->default_backend; + // Evaluate and traverse input variables (args and kwargs) + { + // Enter Resume scope, so we can track gradients + ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, 0, + true); + + // Traverse input variables + ProfilerPhase profiler("traverse input"); + jit_log(LogLevel::Info, "freeze(): Traversing input"); + TraverseContext ctx; + ctx.schedule_force = true; + in_variables->traverse_with_registry(input, ctx); + + { // Evaluate the variables, scheduled when traversing + nb::gil_scoped_release guard; + jit_eval(); + } + + in_variables->record_jit_variables(); + } + + in_heuristics = in_heuristics.max(in_variables->heuristic()); + + raise_if(in_variables->backend == JitBackend::None, + "freeze(): Cannot infer backend without providing input " + "variable to frozen function!"); + + auto it = this->recordings.find(in_variables); + + // Evict the least recently used recording if the cache is "full" + if (max_cache_size > 0 && + recordings.size() >= (uint32_t) max_cache_size && + it == this->recordings.end()) { + + uint32_t lru_last_used = UINT32_MAX; + RecordingMap::iterator lru_it = recordings.begin(); + + for (auto it = recordings.begin(); it != recordings.end(); it++) { + auto &recording = it.value(); + if (recording->last_used < lru_last_used) { + lru_last_used = recording->last_used; + lru_it = it; + } + } + recordings.erase(lru_it); + + it = this->recordings.find(in_variables); + } + + if (it == this->recordings.end()) { + { + // TODO: single traverse + ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, 0, + true); + in_variables->assign_with_registry(input); + } + + // FunctionRecording recording; + auto recording = std::make_unique(); + recording->last_used = call_counter - 1; + + try { + result = recording->record(func, this, input, *in_variables); + } catch (nb::python_error &e) { + in_variables->release(); + jit_freeze_abort(in_variables->backend); + nb::raise_from( + e, PyExc_RuntimeError, + "record(): error encountered while recording a frozen " + "function (see above)."); + } catch (const std::exception &e) { + in_variables->release(); + jit_freeze_abort(in_variables->backend); + + nb::chain_error(PyExc_RuntimeError, "record(): %s", e.what()); + nb::raise_python_error(); + }; + + in_variables->release(); + + this->prev_key = in_variables; + this->recordings.insert( + { std::move(in_variables), std::move(recording) }); + + } else { + FunctionRecording *recording = it.value().get(); + + recording->last_used = call_counter - 1; + + { + result = recording->replay(func, this, input, *in_variables); + } + + // Drop references to variables + in_variables->release(); + } + } + ad_traverse(drjit::ADMode::Backward, + (uint32_t) drjit::ADFlag::ClearVertices); + return result; +} + +void FrozenFunction::clear() { + recordings.clear(); + prev_key = std::make_shared(FlatVariables()); + recording_counter = 0; + call_counter = 0; +} + +/** + * This function inspects the content of the frozen function to detect reference + * cycles, that could lead to memory or type leaks. It can be called by the + * garbage collector by adding it to the ``type_slots`` of the + * ``FrozenFunction`` definition. 
+ */ +int frozen_function_tp_traverse(PyObject *self, visitproc visit, void *arg) { + FrozenFunction *f = nb::inst_ptr(self); + + nb::handle func = nb::find(f->func); + Py_VISIT(func.ptr()); + + for (auto &it : f->recordings) { + for (auto &layout : it.first->layout) { + nb::handle type = nb::find(layout.type); + Py_VISIT(type.ptr()); + nb::handle object = nb::find(layout.py_object); + Py_VISIT(object.ptr()); + } + for (auto &layout : it.second->out_variables.layout) { + nb::handle type = nb::find(layout.type); + Py_VISIT(type.ptr()); + nb::handle object = nb::find(layout.py_object); + Py_VISIT(object.ptr()); + } + } + + return 0; +} + +/** + * This function releases the internal function of the ``FrozenFunction`` + * object. It is used by the garbage collector to "break" potential reference + * cycles, resulting from the frozen function being referenced in the closure of + * the wrapped variable. + */ +int frozen_function_clear(PyObject *self) { + FrozenFunction *f = nb::inst_ptr(self); + + f->func.release(); + + return 0; +} + +// Slot data structure referencing the above two functions +static PyType_Slot slots[] = { { Py_tp_traverse, + (void *) frozen_function_tp_traverse }, + { Py_tp_clear, (void *) frozen_function_clear }, + { 0, nullptr } }; + +void export_freeze(nb::module_ & /*m*/) { + + nb::module_ d = nb::module_::import_("drjit.detail"); + auto traversable_base = + nb::class_(d, "TraversableBase"); + nb::class_(d, "FrozenFunction", nb::type_slots(slots)) + .def(nb::init()) + .def_prop_ro( + "n_cached_recordings", + [](FrozenFunction &self) { return self.n_cached_recordings(); }) + .def_ro("n_recordings", &FrozenFunction::recording_counter) + .def("clear", &FrozenFunction::clear) + .def("__call__", &FrozenFunction::operator()); +} diff --git a/src/python/freeze.h b/src/python/freeze.h new file mode 100644 index 00000000..b02cf643 --- /dev/null +++ b/src/python/freeze.h @@ -0,0 +1,529 @@ +/* + freeze.h -- Bindings for drjit.freeze() + + Dr.Jit: A Just-In-Time-Compiler for Differentiable Rendering + Copyright 2023, Realistic Graphics Lab, EPFL. + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE.txt file. +*/ + +#pragma once + +#include "common.h" +#include +#include +#include +#include +#include +#include + +struct FrozenFunction; + +namespace detail { + +using index64_vector = drjit::detail::index64_vector; + +enum class LayoutFlag : uint32_t { + /// Whether this variable has size 1 + SingletonArray = (1 << 0), + /// Whether this variable is unaligned in memory + Unaligned = (1 << 1), + /// Whether this layout represents a literal variable + Literal = (1 << 2), + /// Whether this variable has gradients enabled + GradEnabled = (1 << 3), + /// Did this variable have gradient edges attached when recording, that + /// where postponed by the ``isolate_grad`` function? + Postponed = (1 << 4), +}; + +/// Stores information about python objects, such as their type, their number of +/// sub-elements or their field keys. This can be used to reconstruct a PyTree +/// from a flattened variable array. +struct Layout { + /// Number of members in this container. + /// Can be used to traverse the layout without knowing the type. + uint32_t num = 0; + /// Optional field identifiers of the container + /// for example: keys in dictionary + std::vector fields; + /// The index in the flat_variables array of this variable. + /// This can be used to determine aliasing. 
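+ /// If the same variable is passed twice in one call, both ``Layout`` entries
+ /// reference the same slot; a call with different aliasing therefore produces
+ /// a different key and triggers re-tracing.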
+ uint32_t index = 0; + + /// Flags, storing information about variables and literals. + uint32_t flags = 0; + + /// The literal data + uint64_t literal = 0; + /// Optional drjit type of the variable + VarType vt = VarType::Void; + + /// If a non drjit type is passed as function arguments or result, we simply + /// cache it here. + /// TODO: possibly do the same for literals? + nb::object py_object; + + /// Nanobind type of the container/variable + nb::type_object type; + + bool operator==(const Layout &rhs) const; + + Layout() = default; + + Layout(const Layout &) = delete; + Layout &operator=(const Layout &) = delete; + + Layout(Layout &&) = default; + Layout &operator=(Layout &&) = default; +}; + +/** + * \brief Stores information about opaque variables. + * + * When traversing a PyTree, literal variables are stored directly and + * non-literal variables are first scheduled and their indices deduplicated and + * added to the ``FlatVariables::variables`` field. After calling ``jit_eval``, + * information about variables can be recorded using + * ``FlatVariables::record_jit_variables``. This struct stores that information + * per deduplicated variable. + */ +struct VarLayout{ + /// Optional drjit type of the variable + VarType vt = VarType::Void; + /// Optional evaluation state of the variable + VarState vs = VarState::Invalid; + /// Flags, storing information about variables + uint32_t flags = 0; + /// We have to track the condition, where two variables have the same size + /// during recording but don't when replaying. + /// Therefore we de-duplicate the size. + uint32_t size_index = 0; + + VarLayout() = default; + + VarLayout(const VarLayout &) = delete; + VarLayout &operator=(const VarLayout &) = delete; + + VarLayout(VarLayout &&) = default; + VarLayout &operator=(VarLayout &&) = default; + + bool operator==(const VarLayout &rhs) const; +}; + + +// Additional context required when traversing the inputs +struct TraverseContext { + /// Set of postponed ad nodes, used to mark inputs to functions. + const tsl::robin_set *postponed = nullptr; + bool schedule_force = false; +}; + +/** + * \brief A flattened representation of the PyTree. + * + * This struct stores a flattened representation of a PyTree as well a + * representation of it. It can therefore be used to either construct the PyTree + * as well as assign the variables to an existing PyTree. Furthermore, this + * struct can also be used as a key to the ``RecordingMap``, determining which + * recording should be used given an input to a frozen function. + * Information about the PyTree is stored in DFS Encoding. Every node of the + * tree is represented by a ``Layout`` element in the ``layout`` vector. + */ +struct FlatVariables { + + uint32_t flags = 0; + + // Index, used to iterate over the variables/layouts when constructing + // python objects + uint32_t layout_index = 0; + + /// The flattened and de-duplicated variable indices of the input/output to + /// a frozen function + std::vector variables; + /// Mapping from drjit jit index to index in flat variables. Used to + /// deduplicate jit indices. + tsl::robin_map index_to_slot; + + /// We have to track the condition, where two variables have the same size + /// during recording but don't when replaying. + /// Therefore we construct equivalence classes of sizes. + /// This vector represents the different sizes, encountered during + /// traversal. The algorithm used to "add" a size is the same as for adding + /// a variable index. 
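+ /// For example, inputs with sizes (10, 10, 4) and (12, 12, 7) both map to the
+ /// equivalence classes (0, 0, 1) and can share a recording, whereas sizes
+ /// (10, 4, 4) map to (0, 1, 1) and require a separate one.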
+ std::vector sizes; + /// Mapping from the size to its index in the ``sizes`` vector. This is used + /// to construct size equivalence classes (i.e. deduplicating sizes). + tsl::robin_map size_to_slot; + + /// This saves information about the type, size and fields of pytree + /// objects. The information is stored in DFS order. + std::vector layout; + /// Stores information about non-literal jit variables. + std::vector var_layout; + /// The collective backend for all input variables. It can be used to ensure + /// that all variables have the same backend. + JitBackend backend = JitBackend::None; + /// The variant, if any, used to traverse the registry. + std::string variant; + /// All domains (deduplicated), encountered while traversing the PyTree and + /// its C++ objects. This can be used to traverse the registry. We use a + /// vector instead of a hash set, since we expect the number of domains not + /// to exceed 100. + std::vector domains; + + uint32_t recursion_level = 0; + + struct recursion_guard { + FlatVariables *flat_variables; + recursion_guard(FlatVariables *flat_variables) + : flat_variables(flat_variables) { + if (++flat_variables->recursion_level >= 50) { + PyErr_SetString(PyExc_RecursionError, + "runaway recursion detected"); + nb::raise_python_error(); + } + } + ~recursion_guard() { flat_variables->recursion_level--; } +}; + + + /** + * Describes how many elements have to be pre-allocated for the ``layout``, + * ``index_to_slot`` and ``size_to_slot`` containers. + */ + struct Heuristic { + size_t layout = 0; + size_t index_to_slot = 0; + size_t size_to_slot = 0; + + Heuristic max(Heuristic rhs){ + return Heuristic{ + std::max(layout, rhs.layout), + std::max(index_to_slot, rhs.index_to_slot), + std::max(size_to_slot, rhs.size_to_slot), + }; + } + }; + + FlatVariables() {} + FlatVariables(Heuristic heuristic) { + layout.reserve(heuristic.layout); + index_to_slot.reserve(heuristic.index_to_slot); + size_to_slot.reserve(heuristic.size_to_slot); + } + + FlatVariables(const FlatVariables &) = delete; + FlatVariables &operator=(const FlatVariables &) = delete; + + FlatVariables(FlatVariables &&) = default; + FlatVariables &operator=(FlatVariables &&) = default; + + void clear() { + this->layout_index = 0; + this->variables.clear(); + this->index_to_slot.clear(); + this->layout.clear(); + this->backend = JitBackend::None; + } + /// Borrow all variables held by this struct. + void borrow() { + for (uint32_t &index : this->variables) + jit_var_inc_ref(index); + } + /// Release all variables held by this struct. + void release() { + for (uint32_t &index : this->variables) + jit_var_dec_ref(index); + } + + /** + * \brief Records information about jit variables, that have been traversed. + * + * After traversing the PyTree, collecting non-literal indices in + * ``variables`` and evaluating the collected indices, we can collect + * information about the underlying variables. This information is used in + * the key of the ``RecordingMap`` to determine which recording should be + * replayed or if the function has to be re-traced. This function iterates + * over the collected indices and collects that information. + */ + void record_jit_variables(); + + /** + * Returns a struct representing heuristics to pre-allocate memory for the + * layout, of the flat variables. + */ + Heuristic heuristic() { + return Heuristic{ + layout.size(), + index_to_slot.size(), + size_to_slot.size(), + }; + }; + + /** + * \brief Add a variant domain pair to be traversed using the registry. 
+ * + * When traversing a jit variable, that references a pointer to a class, + * such as a BSDF or Shape in Mitsuba, we have to traverse all objects + * registered with that variant-domain pair in the registry. This function + * adds the variant-domain pair, deduplicating the domain. Whether a + * variable references a class is represented by it's ``IsClass`` const + * attribute. If the domain is an empty string (""), this function skips + * adding the variant-domain pair. + */ + void add_domain(const char *variant, const char *domain); + + /** + * Adds a jit index to the flattened array, deduplicating it. + * This allows to check for aliasing conditions, where two variables + * actually refer to the same index. The function should only be called for + * scheduled non-literal variable indices. + */ + uint32_t add_jit_index(uint32_t variable_index); + + /** + * This function returns an index into the ``sizes`` vector, representing an + * equivalence class of variable sizes. It uses a HashMap and vector to + * deduplicate sizes. + * + * This is necessary, to catch cases, where two variables had the same size + * when recording a function and two different sizes when replaying. + * In that case one kernel would be recorded, that evaluates both variables. + * However, when replaying two kernels would have to be launched since the + * now differently sized variables cannot be evaluated by the same kernel. + */ + uint32_t add_size(uint32_t size); + + /** + * Traverse the variable referenced by a jit index and add it to the flat + * variables. An optional type python type can be supplied if it is known. + * Depending on the ``TraverseContext::schedule_force`` the underlying + * variable is either scheduled (``jit_var_schedule``) or force scheduled + * (``jit_var_schedule_force``). If the variable after evaluation is a + * literal, it is directly recorded in the ``layout`` otherwise, it is added + * to the ``variables`` array, allowing the variables to be used when + * recording the frozen function. + */ + void traverse_jit_index(uint32_t index, TraverseContext &ctx, + nb::handle tp = nullptr); + /** + * Add an ad variable by it's index. Both the value and gradient are added + * to the flattened variables. If the ad index has been marked as postponed + * in the \c TraverseContext.postponed field, we mark the resulting layout + * with that flag. This will cause the gradient edges to be propagated when + * assigning to the input. The function takes an optional python-type if + * it is known. + */ + void traverse_ad_index(uint64_t index, TraverseContext &ctx, + nb::handle tp = nullptr); + + /** + * Wrapper aground traverse_ad_index for a python handle. + */ + void traverse_ad_var(nb::handle h, TraverseContext &ctx); + + /** + * Traverse a c++ tree using it's `traverse_1_cb_ro` callback. + */ + void traverse_cb(const drjit::TraversableBase *traversable, + TraverseContext &ctx, nb::object type = nb::none()); + + /** + * Traverses a PyTree in DFS order, and records it's layout in the + * `layout` vector. + * + * When hitting a drjit primitive type, it calls the + * `traverse_dr_var` method, which will add their indices to the + * `flat_variables` vector. The collect method will also record metadata + * about the drjit variable in the layout. Therefore, the layout can be used + * as an identifier to the recording of the frozen function. + */ + void traverse(nb::handle h, TraverseContext &ctx); + + /** + * First traverses the PyTree, then the registry. 
This ensures that + * additional data to vcalls is tracked correctly. + */ + void traverse_with_registry(nb::handle h, TraverseContext &ctx); + + /** + * Construct a variable, given it's layout. + * This is the counterpart to `traverse_jit_index`. + * + * Optionally, the index of a variable can be provided that will be + * overwritten with the result of this function. In that case, the function + * will check for compatible variable types. + */ + uint32_t construct_jit_index(uint32_t prev_index = 0); + + /** + * Construct/assign the variable index given a layout. + * This corresponds to `traverse_ad_index`> + * + * This function is also used for assignment to ad-variables. + * If a `prev_index` is provided, and it is an ad-variable the gradient and + * value of the flat variables will be applied to the ad variable, + * preserving the ad_idnex. + * + * It returns an owning reference. + */ + uint64_t construct_ad_index(uint32_t shrink = 0, + uint64_t prev_index = 0); + + /** + * Construct an ad variable given it's layout. + * This corresponds to `traverse_ad_var` + */ + nb::object construct_ad_var(const Layout &layout, uint32_t shrink = 0); + + /** + * This is the counterpart to the traverse method, used to construct the + * output of a frozen function. Given a layout vector and flat_variables, it + * re-constructs the PyTree. + */ + nb::object construct(); + + /** + * Assigns an ad variable. + * Corresponds to `traverse_ad_var`. + * This uses `construct_ad_index` to either construct a new ad variable or + * assign the value and gradient to an already existing one. + */ + void assign_ad_var(Layout &layout, nb::handle dst); + + /** + * Helper function, used to assign a callback variable. + * + * \param tmp + * This vector is populated with the indices to variables that have been + * constructed. It is required to release the references, since the + * references created by `construct_ad_index` are owning and they are + * borrowed after the callback returns. + */ + uint64_t assign_cb_internal(uint64_t index, index64_vector &tmp); + + /** + * Assigns variables using it's `traverse_cb_rw` callback. + * This corresponds to `traverse_cb`. + */ + void assign_cb(drjit::TraversableBase *traversable); + + /** + * Assigns the flattened variables to an already existing PyTree. + * This is used when input variables have changed. + */ + void assign(nb::handle dst); + + /** + * First assigns the registry and then the PyTree. + * Corresponds to `traverse_with_registry`. + */ + void assign_with_registry(nb::handle dst); + + bool operator==(const FlatVariables &rhs) const { + return this->layout == rhs.layout && + this->var_layout == rhs.var_layout && this->flags == rhs.flags; + } +}; + +/// Helper struct to hash input variables +struct FlatVariablesHasher { + size_t operator()(const std::shared_ptr &key) const; +}; + +/// Helper struct to compare input variables +struct FlatVariablesEqual{ + using is_transparent = void; + bool operator()(const std::shared_ptr &lhs, + const std::shared_ptr &rhs) const { + return *lhs.get() == *rhs.get(); + } +}; + +/** + * \brief A recording of a frozen function, recorded with a certain layout of + * input variables. 
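+ *
+ * It owns the underlying ``Recording`` of kernel launches, the flattened
+ * output variables used to reconstruct the function's return value, and a
+ * ``last_used`` counter that drives LRU eviction when a cache limit is set.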
+ */ +struct FunctionRecording { + uint32_t last_used = 0; + Recording *recording = nullptr; + FlatVariables out_variables; + + FunctionRecording() : out_variables() {} + FunctionRecording(const FunctionRecording &) = delete; + FunctionRecording &operator=(const FunctionRecording &) = delete; + FunctionRecording(FunctionRecording &&) = default; + FunctionRecording &operator=(FunctionRecording &&) = default; + + ~FunctionRecording() { + if (this->recording) { + jit_freeze_destroy(this->recording); + } + this->recording = nullptr; + } + + void clear() { + if (this->recording) { + jit_freeze_destroy(this->recording); + } + this->recording = nullptr; + this->out_variables = FlatVariables(); + } + + /* + * Record a function, given it's python input and flattened input. + */ + nb::object record(nb::callable func, FrozenFunction *frozen_func, + nb::list input, const FlatVariables &in_variables); + /* + * Replays the recording. + * + * This constructs the output and re-assigns the input. + */ + nb::object replay(nb::callable func, FrozenFunction *frozen_func, + nb::list input, const FlatVariables &in_variables); +}; + +using RecordingMap = tsl::robin_map, + std::unique_ptr, + FlatVariablesHasher, FlatVariablesEqual>; + +} // namespace detail + +struct FrozenFunction { + nb::callable func; + + detail::RecordingMap recordings; + std::shared_ptr prev_key; + + uint32_t recording_counter = 0; + uint32_t call_counter = 0; + int max_cache_size = -1; + uint32_t warn_recording_count = 10; + JitBackend default_backend = JitBackend::None; + + detail::FlatVariables::Heuristic in_heuristics; + + FrozenFunction(nb::callable func, int max_cache_size = -1, + uint32_t warn_recording_count = 10, + JitBackend backend = JitBackend::None) + : func(func), max_cache_size(max_cache_size), + warn_recording_count(warn_recording_count), default_backend(backend) { + } + ~FrozenFunction() {} + + FrozenFunction(const FrozenFunction &) = delete; + FrozenFunction &operator=(const FrozenFunction &) = delete; + FrozenFunction(FrozenFunction &&) = default; + FrozenFunction &operator=(FrozenFunction &&) = default; + + uint32_t n_cached_recordings() { return this->recordings.size(); } + + void clear(); + + nb::object operator()(nb::args args, nb::kwargs kwargs); +}; + +extern void export_freeze(nb::module_ &); diff --git a/src/python/main.cpp b/src/python/main.cpp index 99712f3e..27cb02df 100644 --- a/src/python/main.cpp +++ b/src/python/main.cpp @@ -8,7 +8,7 @@ BSD-style license that can be found in the LICENSE.txt file. 
*/ -#define NB_INTRUSIVE_EXPORT NB_EXPORT +#define NB_INTRUSIVE_EXPORT NB_IMPORT #include #include @@ -22,6 +22,7 @@ #include "cuda.h" #include "reduce.h" #include "eval.h" +#include "freeze.h" #include "iter.h" #include "init.h" #include "memop.h" @@ -106,6 +107,9 @@ NB_MODULE(_drjit_ext, m_) { .value("ScatterReduceLocal", JitFlag::ScatterReduceLocal, doc_JitFlag_ScatterReduceLocal) .value("SymbolicConditionals", JitFlag::SymbolicConditionals, doc_JitFlag_SymbolicConditionals) .value("SymbolicScope", JitFlag::SymbolicScope, doc_JitFlag_SymbolicScope) + .value("KernelFreezing", JitFlag::KernelFreezing, doc_JitFlag_KernelFreezing) + .value("FreezingScope", JitFlag::FreezingScope, doc_JitFlag_FreezingScope) + .value("EnableObjectTraversal", JitFlag::EnableObjectTraversal, doc_JitFlag_EnableObjectTraversal) .value("Default", JitFlag::Default, doc_JitFlag_Default) // Deprecated aliases @@ -235,6 +239,7 @@ NB_MODULE(_drjit_ext, m_) { export_iter(detail); export_reduce(m); export_eval(m); + export_freeze(m); export_memop(m); export_slice(m); export_dlpack(m); diff --git a/src/python/texture.h b/src/python/texture.h index e8d00985..4068ec22 100644 --- a/src/python/texture.h +++ b/src/python/texture.h @@ -173,6 +173,8 @@ void bind_texture(nb::module_ &m, const char *name) { #undef def_tex_eval_cubic_helper tex.attr("IsTexture") = true; + + drjit::bind_traverse(tex); } template diff --git a/src/python/tracker.cpp b/src/python/tracker.cpp index ca68abf5..6c048ce3 100644 --- a/src/python/tracker.cpp +++ b/src/python/tracker.cpp @@ -191,8 +191,8 @@ struct VariableTracker::Context { check_size(check_size), index_offset(0) { } // Internal API for type-erased traversal - uint64_t _traverse_write(uint64_t idx); - void _traverse_read(uint64_t index); + uint64_t _traverse_write(uint64_t idx, const char *, const char *); + void _traverse_read(uint64_t index, const char *, const char *); }; // Temporarily push a value onto the stack @@ -574,7 +574,8 @@ bool VariableTracker::Impl::traverse(Context &ctx, nb::handle h) { return changed; } -uint64_t VariableTracker::Context::_traverse_write(uint64_t idx) { +uint64_t VariableTracker::Context::_traverse_write(uint64_t idx, const char *, + const char *) { if (!idx) return 0; if (index_offset >= indices.size()) @@ -611,7 +612,7 @@ uint64_t VariableTracker::Context::_traverse_write(uint64_t idx) { return idx_new; } -void VariableTracker::Context::_traverse_read(uint64_t index) { +void VariableTracker::Context::_traverse_read(uint64_t index, const char *, const char *) { if (!index) return; indices.push_back(ad_var_inc_ref(index)); diff --git a/tests/call_ext.cpp b/tests/call_ext.cpp index 90dd869c..8b5d85de 100644 --- a/tests/call_ext.cpp +++ b/tests/call_ext.cpp @@ -6,6 +6,7 @@ #include #include #include +#include namespace nb = nanobind; namespace dr = drjit; @@ -13,29 +14,25 @@ namespace dr = drjit; using namespace nb::literals; template -struct Sampler { +struct Sampler : dr::TraversableBase { + Sampler() : rng(1) {} Sampler(size_t size) : rng(size) { } T next() { return rng.next_float32(); } - void traverse_1_cb_ro(void *payload, void (*fn)(void *, uint64_t)) const { - traverse_1_fn_ro(rng, payload, fn); - } - - void traverse_1_cb_rw(void *payload, uint64_t (*fn)(void *, uint64_t)) { - traverse_1_fn_rw(rng, payload, fn); - } - dr::PCG32> rng; + + DR_TRAVERSE_CB(dr::TraversableBase, rng); }; -template struct Base : nb::intrusive_base { +template struct Base : drjit::TraversableBase { using Mask = dr::mask_t; using UInt32 = dr::uint32_array_t; virtual std::pair 
f(Float x, Float y) = 0; virtual std::pair f_masked(const std::pair &xy, Mask active) = 0; virtual Float g(Float, Mask) = 0; + virtual Float h(Float) = 0; virtual Float nested(Float x, UInt32 s) = 0; virtual void dummy() = 0; virtual float scalar_getter() = 0; @@ -50,10 +47,12 @@ template struct Base : nb::intrusive_base { Base() { if constexpr (dr::is_jit_v) - jit_registry_put("", "Base", this); + drjit::registry_put("", "Base", this); } virtual ~Base() { jit_registry_remove(this); } + + DR_TRAVERSE_CB(drjit::TraversableBase) }; template struct A : Base { @@ -74,6 +73,10 @@ template struct A : Base { return value; } + virtual Float h(Float x) override{ + return value + x; + } + virtual Float nested(Float x, UInt32 /*s*/) override { return x + dr::gather(value, UInt32(0)); } @@ -112,6 +115,8 @@ template struct A : Base { uint32_t scalar_property; Float value, extra_value; Float opaque = dr::opaque(1.f); + + DR_TRAVERSE_CB(Base, value, opaque) }; template struct B : Base { @@ -132,6 +137,10 @@ template struct B : Base { return value*x; } + virtual Float h(Float x) override{ + return value - x; + } + virtual Float nested(Float x, UInt32 s) override { using BaseArray = dr::replace_value_t*>; BaseArray self = dr::reinterpret_array(s); @@ -160,6 +169,8 @@ template struct B : Base { Float value; Float opaque = dr::opaque(2.f); + + DR_TRAVERSE_CB(Base, value, opaque) }; DRJIT_CALL_TEMPLATE_BEGIN(Base) @@ -167,6 +178,7 @@ DRJIT_CALL_TEMPLATE_BEGIN(Base) DRJIT_CALL_METHOD(f_masked) DRJIT_CALL_METHOD(dummy) DRJIT_CALL_METHOD(g) + DRJIT_CALL_METHOD(h) DRJIT_CALL_METHOD(nested) DRJIT_CALL_METHOD(sample) DRJIT_CALL_METHOD(gather_packet) @@ -198,20 +210,22 @@ void bind(nb::module_ &m) { using Sampler = ::Sampler; auto sampler = nb::class_(m, "Sampler") + .def(nb::init<>()) .def(nb::init()) .def("next", &Sampler::next) .def_rw("rng", &Sampler::rng); bind_traverse(sampler); - nb::class_(m, "Base") + auto base_cls = nb::class_(m, "Base") .def("f", &BaseT::f) .def("f_masked", &BaseT::f_masked) .def("g", &BaseT::g) .def("nested", &BaseT::nested) .def("sample", &BaseT::sample); + bind_traverse(base_cls); - nb::class_(m, "A") + auto a_cls = nb::class_(m, "A") .def(nb::init<>()) .def("a_get_property", &AT::a_get_property) .def("a_gather_extra_value", &AT::a_gather_extra_value) @@ -219,11 +233,13 @@ void bind(nb::module_ &m) { .def_rw("value", &AT::value) .def_rw("extra_value", &AT::extra_value) .def_rw("scalar_property", &AT::scalar_property); + bind_traverse(a_cls); - nb::class_(m, "B") + auto b_cls = nb::class_(m, "B") .def(nb::init<>()) .def_rw("opaque", &BT::opaque) .def_rw("value", &BT::value); + bind_traverse(b_cls); using BaseArray = dr::DiffArray; m.def("dispatch_f", [](BaseArray &self, Float a, Float b) { @@ -243,6 +259,7 @@ void bind(nb::module_ &m) { .def("g", [](BaseArray &self, Float x, Mask m) { return self->g(x, m); }, "x"_a, "mask"_a = true) + .def("h", [](BaseArray &self, Float x) { return self->h(x); }, "x"_a) .def("nested", [](BaseArray &self, Float x, UInt32 s) { return self->nested(x, s); }, "x"_a, "s"_a) diff --git a/tests/custom_type_ext.cpp b/tests/custom_type_ext.cpp index 50c936ee..7bff0877 100644 --- a/tests/custom_type_ext.cpp +++ b/tests/custom_type_ext.cpp @@ -1,6 +1,10 @@ -#include #include #include +#include +#include +#include +#include +#include namespace nb = nanobind; namespace dr = drjit; @@ -42,6 +46,65 @@ struct CustomHolder { Value m_value; }; +class Object : public drjit::TraversableBase { + DR_TRAVERSE_CB(drjit::TraversableBase); +}; + +template +class CustomBase : 
public Object{ + Value m_base_value; + +public: + CustomBase(const Value &base_value) : Object(), m_base_value(base_value) {} + + Value &base_value() { return m_base_value; } + virtual Value &value() = 0; + + DR_TRAVERSE_CB(Object, m_base_value); +}; + +template +class PyCustomBase : public CustomBase{ +public: + using Base = CustomBase; + NB_TRAMPOLINE(Base, 1); + + PyCustomBase(const Value &base_value) : Base(base_value) {} + + Value &value() override { NB_OVERRIDE_PURE(value); } + + DR_TRAMPOLINE_TRAVERSE_CB(Base); +}; + +template +class CustomA: public CustomBase{ +public: + using Base = CustomBase; + + CustomA(const Value &value, const Value &base_value) : Base(base_value), m_value(value) {} + + Value &value() override { return m_value; } + +private: + Value m_value; + + DR_TRAVERSE_CB(Base, m_value); +}; + +template +class Nested: Object{ + using Base = Object; + + std::vector, size_t>> m_nested; + +public: + Nested(nb::ref a, nb::ref b) { + m_nested.push_back(std::make_pair(a, 0)); + m_nested.push_back(std::make_pair(b, 1)); + } + + DR_TRAVERSE_CB(Base, m_nested); +}; template void bind(nb::module_ &m) { dr::ArrayBinding b; @@ -64,12 +127,50 @@ template void bind(nb::module_ &m) { .def(nb::init()) .def("value", &CustomFloatHolder::value, nanobind::rv_policy::reference); + using CustomBase = CustomBase; + using PyCustomBase = PyCustomBase; + using CustomA = CustomA; + using Nested = Nested; + + auto object = nb::class_( + m, "Object", + nb::intrusive_ptr( + [](Object *o, PyObject *po) noexcept { o->set_self_py(po); })); + + auto base = + nb::class_(m, "CustomBase") + .def(nb::init()) + .def("value", nb::overload_cast<>(&CustomBase::value)) + .def("base_value", nb::overload_cast<>(&CustomBase::base_value)); + + drjit::bind_traverse(base); + + auto a = nb::class_(m, "CustomA") + .def(nb::init()); + + drjit::bind_traverse(a); + + auto nested = nb::class_(m, "Nested") + .def(nb::init, nb::ref>()); + + drjit::bind_traverse(nested); + m.def("cpp_make_opaque", [](CustomFloatHolder &holder) { dr::make_opaque(holder); } ); } NB_MODULE(custom_type_ext, m) { + nb::intrusive_init( + [](PyObject *o) noexcept { + nb::gil_scoped_acquire guard; + Py_INCREF(o); + }, + [](PyObject *o) noexcept { + nb::gil_scoped_acquire guard; + Py_DECREF(o); + }); + #if defined(DRJIT_ENABLE_LLVM) nb::module_ llvm = m.def_submodule("llvm"); bind(llvm); diff --git a/tests/test_custom_type_ext.py b/tests/test_custom_type_ext.py index 90c9b7a2..ad97bb3e 100644 --- a/tests/test_custom_type_ext.py +++ b/tests/test_custom_type_ext.py @@ -1,7 +1,6 @@ import drjit as dr import pytest - def get_pkg(t): with dr.detail.scoped_rtld_deepbind(): m = pytest.importorskip("custom_type_ext") @@ -69,3 +68,96 @@ def test03_cpp_make_opaque(t): pkg.cpp_make_opaque(holder) assert holder.value().state == dr.VarState.Evaluated + + +@pytest.test_arrays("float32,-diff,shape=(*),jit") +def test04_traverse_opaque(t): + """ + Tests that it is possible to traverse an opaque C++ object. + """ + pkg = get_pkg(t) + Float = t + + value = dr.arange(Float, 10) + base_value = dr.arange(Float, 10) + + a = pkg.CustomA(value, base_value) + assert dr.detail.collect_indices(a) == [base_value.index, value.index] + + +@pytest.test_arrays("float32,-diff,shape=(*),jit") +def test05_traverse_py(t): + """ + Tests the implementation of ``traverse_py_cb_ro``, which is used to traverse + python objects in trampoline classes. 
+ """ + Float = t + + v = dr.arange(Float, 10) + + class PyClass: + def __init__(self, v) -> None: + self.v = v + + c = PyClass(v) + + result = [] + + def callback(index, domain, variant): + result.append(index) + + dr.detail.traverse_py_cb_ro(c, callback) + + assert result == [v.index] + + +@pytest.test_arrays("float32,-diff,shape=(*),jit") +def test06_trampoline_traversal(t): + """ + Tests that classes inheriting from trampoline classes are traversed + automatically. + """ + pkg = get_pkg(t) + Float = t + + value = dr.opaque(Float, 0, 3) + base_value = dr.opaque(Float, 1, 3) + + class B(pkg.CustomBase): + def __init__(self, value, base_value) -> None: + super().__init__(base_value) + self._value = value + + def value(self): + return self._value + + b = B(value, base_value) + + assert dr.detail.collect_indices(b) == [base_value.index, value.index] + +@pytest.test_arrays("float32,-diff,shape=(*),jit") +def test07_nested_traversal(t): + """ + Test traversal of nested objects, and more specifically the traversal of + ``std::vector, size_t>>`` members. + """ + pkg = get_pkg(t) + Float = t + + value = dr.arange(Float, 10) + 0 + base_value = dr.arange(Float, 10) + 1 + + a = pkg.CustomA(value, base_value) + + value = dr.arange(Float, 10) + 2 + base_value = dr.arange(Float, 10) + 3 + + b = pkg.CustomA(value, base_value) + + nested = pkg.Nested(a, b) + + indices_a = dr.detail.collect_indices(a) + indices_b = dr.detail.collect_indices(b) + indices_nested = dr.detail.collect_indices(nested) + + assert indices_nested == indices_a + indices_b diff --git a/tests/test_freeze.py b/tests/test_freeze.py new file mode 100644 index 00000000..c9783309 --- /dev/null +++ b/tests/test_freeze.py @@ -0,0 +1,2865 @@ +import drjit as dr +import pytest +from math import ceil +from dataclasses import dataclass +import sys + +def get_single_entry(x): + tp = type(x) + result = x + shape = dr.shape(x) + if len(shape) == 2: + result = result[shape[0] - 1] + if len(shape) == 3: + result = result[shape[0] - 1][shape[1] - 1] + return result + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test01_basic(t): + """ + Tests a very basic frozen function, adding two integers x, y. + """ + + @dr.freeze + def func(x, y): + return x + y + + i0 = t(0, 1, 2) + i1 = t(2, 1, 0) + + o0 = func(i0, i1) + assert dr.all(t(2, 2, 2) == o0) + + i0 = t(1, 2, 3) + i1 = t(3, 2, 1) + + o0 = func(i0, i1) + assert dr.all(t(4, 4, 4) == o0) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test02_flush_kernel_cache(t): + """ + Tests that flushing the kernel between recording and replaying a frozen + function causes the function to be re-traced. + """ + + def func(x, y): + return x + y + + frozen = dr.freeze(func) + + x = t(0, 1, 2) + y = t(2, 1, 0) + + res = frozen(x, y) + ref = func(x, y) + assert dr.all(res == ref) + + dr.flush_kernel_cache() + + x = t(1, 2, 3) + y = t(3, 2, 1) + + # Flushing the kernel cache should force a re-trace + res = frozen(x, y) + ref = func(x, y) + assert dr.all(res == ref) + assert frozen.n_recordings == 2 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test03_output_tuple(t): + """ + Tests that returning tuples from frozen functions is possible. 
+ """ + + @dr.freeze + def func(x, y): + return (x + y, x * y) + + i0 = t(0, 1, 2) + i1 = t(2, 1, 0) + + (o0, o1) = func(i0, i1) + assert dr.all(t(2, 2, 2) == o0) + assert dr.all(t(0, 1, 0) == o1) + + i0 = t(1, 2, 3) + i1 = t(3, 2, 1) + + (o0, o1) = func(i0, i1) + assert dr.all(t(4, 4, 4) == o0) + assert dr.all(t(3, 4, 3) == o1) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test04_output_list(t): + """ + Tests that returning lists from forzen functions is possible. + """ + + @dr.freeze + def func(x, y): + return [x + y, x * y] + + i0 = t(0, 1, 2) + i1 = t(2, 1, 0) + + [o0, o1] = func(i0, i1) + assert dr.all(t(2, 2, 2) == o0) + assert dr.all(t(0, 1, 0) == o1) + + i0 = t(1, 2, 3) + i1 = t(3, 2, 1) + + [o0, o1] = func(i0, i1) + assert dr.all(t(4, 4, 4) == o0) + assert dr.all(t(3, 4, 3) == o1) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test05_output_dict(t): + """ + Tests that returning dictionaries from forzen functions is possible. + """ + + @dr.freeze + def func(x, y): + return {"add": x + y, "mul": x * y} + + i0 = t(0, 1, 2) + i1 = t(2, 1, 0) + + o = func(i0, i1) + o0 = o["add"] + o1 = o["mul"] + assert dr.all(t(2, 2, 2) == o0) + assert dr.all(t(0, 1, 0) == o1) + + i0 = t(1, 2, 3) + i1 = t(3, 2, 1) + + o = func(i0, i1) + o0 = o["add"] + o1 = o["mul"] + assert dr.all(t(4, 4, 4) == o0) + assert dr.all(t(3, 4, 3) == o1) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test06_nested_tuple(t): + """ + Tests that returning nested tuples from forzen functions is possible. + """ + + @dr.freeze + def func(x): + return (x + 1, x + 2, (x + 3, x + 4)) + + i0 = t(0, 1, 2) + + (o0, o1, (o2, o3)) = func(i0) + assert dr.all(t(1, 2, 3) == o0) + assert dr.all(t(2, 3, 4) == o1) + assert dr.all(t(3, 4, 5) == o2) + assert dr.all(t(4, 5, 6) == o3) + + i0 = t(1, 2, 3) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test07_drjit_struct(t): + """ + Tests that returning custom classes, annotated with ``DRJIT_STRUCT`` from + forzen functions is possible. + """ + + class Point: + x: t + y: t + DRJIT_STRUCT = {"x": t, "y": t} + + @dr.freeze + def func(x): + p = Point() + p.x = x + 1 + p.y = x + 2 + return p + + i0 = t(0, 1, 2) + + o = func(i0) + o0 = o.x + o1 = o.y + assert dr.all(t(1, 2, 3) == o0) + assert dr.all(t(2, 3, 4) == o1) + + i0 = t(1, 2, 3) + + o = func(i0) + o0 = o.x + o1 = o.y + assert dr.all(t(2, 3, 4) == o0) + assert dr.all(t(3, 4, 5) == o1) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test08_dataclass(t): + """ + Tests that returning custom dataclasses from forzen functions is possible. + """ + + @dataclass + class Point: + x: t + y: t + + @dr.freeze + def func(x): + p = Point(x + 1, x + 2) + return p + + i0 = t(0, 1, 2) + + o = func(i0) + o0 = o.x + o1 = o.y + assert dr.all(t(1, 2, 3) == o0) + assert dr.all(t(2, 3, 4) == o1) + + i0 = t(1, 2, 3) + + o = func(i0) + o0 = o.x + o1 = o.y + assert dr.all(t(2, 3, 4) == o0) + assert dr.all(t(3, 4, 5) == o1) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test09_traverse_cb(t): + """ + Tests that passing opaque C++ objects to frozen functions is possible. + It should not be possible to return these from frozen functions. 
+ """ + pkg = get_pkg(t) + Sampler = pkg.Sampler + + def func(sampler): + return sampler.next() + + frozen = dr.freeze(func) + + sampler_frozen = Sampler(10) + sampler_func = Sampler(10) + + result1_frozen = frozen(sampler_frozen) + result1_func = func(sampler_func) + assert dr.allclose(result1_frozen, result1_func) + + sampler_frozen = Sampler(10) + sampler_func = Sampler(10) + + result2_frozen = frozen(sampler_frozen) + result2_func = func(sampler_func) + assert dr.allclose(result2_frozen, result2_func) + + assert frozen.n_recordings == 1 + + result3_frozen = frozen(sampler_frozen) + result3_func = func(sampler_func) + assert dr.allclose(result3_func, result3_frozen) + + assert frozen.n_recordings == 1 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test10_scatter(t): + """ + Tests that it is possible to scatter to the input of a frozen function, + while leaving variables depending on the input the same (scattering problem). + """ + + @dr.freeze + def func(x): + dr.scatter(x, 0, dr.arange(t, 3)) + + x = t(0, 1, 2) + func(x) + + x = t(0, 1, 2) + y = x + 1 + z = x + w = t(x) + + func(x) + + assert dr.all(t(0, 0, 0) == x) + assert dr.all(t(1, 2, 3) == y) + assert dr.all(t(0, 0, 0) == z) + assert dr.all(t(0, 1, 2) == w) + + +@pytest.test_arrays("float32, jit, is_diff, shape=(*)") +def test11_optimization(t): + """ + Implements a simple gradient descent optimization of a variable in a + frozen function. This verifies that gradient descent kernels are evaluated + correctly. + """ + + @dr.freeze + def func(state, ref): + for k, x in state.items(): + dr.enable_grad(x) + loss = dr.mean(dr.square(x - ref)) + + dr.backward(loss) + + grad = dr.grad(x) + dr.disable_grad(x) + state[k] = x - grad + + state = {"x": t(0, 0, 0, 0)} + + ref = t(1, 1, 1, 1) + func(state, ref) + assert dr.allclose(t(0.5, 0.5, 0.5, 0.5), state["x"]) + + state = {"x": t(0, 0, 0, 0)} + ref = t(1, 1, 1, 1) + func(state, ref) + + assert dr.allclose(t(0.5, 0.5, 0.5, 0.5), state["x"]) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test12_resized(t): + """ + Tests that it is possible to call a frozen function with inputs of different + size compared to the recording without having to re-trace the function. + """ + + @dr.freeze + def func(x, y): + return x + y + + i0 = t(0, 1, 2) + i1 = t(2, 1, 0) + + o0 = func(i0, i1) + assert dr.all(t(2, 2, 2) == o0) + + i0 = dr.arange(t, 64) + dr.opaque(t, 0) + i1 = dr.arange(t, 64) + dr.opaque(t, 0) + r0 = i0 + i1 + dr.eval(i0, i1, r0) + + o0 = func(i0, i1) + assert dr.all(r0 == o0) + assert func.n_recordings == 1 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test13_changed_input_dict(t): + """ + Test that it is possible to pass a dictionary to a frozen function, that is + inserting the result at a new key in said dictionary. This ensures that the + input is evaluated correctly, and the dictionary is back-assigned to the input. + """ + + @dr.freeze + def func(d: dict): + d["y"] = d["x"] + 1 + + x = t(0, 1, 2) + d = {"x": x} + + func(d) + assert dr.all(t(1, 2, 3) == d["y"]) + + x = t(1, 2, 3) + d = {"x": x} + + func(d) + assert dr.all(t(2, 3, 4) == d["y"]) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test14_changed_input_dataclass(t): + """ + Tests that it is possible to asing to the input of a dataclass inside a + frozen function. This also relies on correct back-assignment of the input. 
+ """ + + @dataclass + class Point: + x: t + + @dr.freeze + def func(p: Point): + p.x = p.x + 1 + + p = Point(x=t(0, 1, 2)) + + func(p) + assert dr.all(t(1, 2, 3) == p.x) + + p = Point(x=t(1, 2, 3)) + + func(p) + assert dr.all(t(2, 3, 4) == p.x) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test15_kwargs(t): + """ + Tests that it is possible to pass keyword arguments to a frozen function + that modifies them. + """ + + @dr.freeze + def func(x=t(0, 1, 2)): + return x + 1 + + y = func(x=t(0, 1, 2)) + assert dr.all(t(1, 2, 3) == y) + + y = func(x=t(1, 2, 3)) + assert dr.all(t(2, 3, 4) == y) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test16_opaque(t): + """ + Tests that changing from an opaque (1-sized array) to an array of size + larger than 1 causes the funcion to be re-traced. This is necessary, because + different kernels are compiled for the two cases. + """ + + @dr.freeze + def func(x, y): + return x + y + + x = t(0, 1, 2) + dr.set_label(x, "x") + y = dr.opaque(t, 1) + dr.set_label(y, "y") + z = func(x, y) + assert dr.all(t(1, 2, 3) == z) + + x = t(1, 2, 3) + y = t(1, 2, 3) + z = func(x, y) + assert dr.all(t(2, 4, 6) == z) + + assert func.n_recordings == 2 + + +@pytest.test_arrays("float32, jit, -is_diff, shape=(*)") +def test17_performance(t): + """ + Tests the performance of a frozen function versus a non-frozen function. + """ + import time + + n = 1024 + n_iter = 1_000 + n_iter_warmeup = 10 + + def func(x, y): + z = 0.5 + result = dr.fma(dr.square(x), y, z) + result = dr.sqrt(dr.abs(result) + dr.power(result, 10)) + result = dr.log(1 + result) + return result + + frozen = dr.freeze(func) + + for name, fn in [("normal", func), ("frozen", frozen)]: + x = dr.arange(t, n) # + dr.opaque(t, i) + y = dr.arange(t, n) # + dr.opaque(t, i) + dr.eval(x, y) + for i in range(n_iter + n_iter_warmeup): + if i == n_iter_warmeup: + t0 = time.time() + + result = fn(x, y) + + dr.eval(result) + + dr.sync_thread() + elapsed = time.time() - t0 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test18_aliasing(t): + """ + Tests that changing the inputs from being the same variable to two different + variables causes the function to be re-traced. + """ + + @dr.freeze + def func(x, y): + return x + y + + x = t(0, 1, 2) + y = x + z = func(x, y) + assert dr.all(t(0, 2, 4) == z) + + x = t(1, 2, 3) + y = x + z = func(x, y) + assert dr.all(t(2, 4, 6) == z) + + x = t(1, 2, 3) + y = t(2, 3, 4) + z = func(x, y) + assert dr.all(t(3, 5, 7) == z) + assert func.n_recordings == 2 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test19_non_jit_types(t): + """ + Tests that it is possible to pass non-jit types such as integers to frozen + functions. + """ + + def func(x, y): + return x + y + + frozen = dr.freeze(func) + + for i in range(3): + x = t(1, 2, 3) + y = i + + res = frozen(x, y) + ref = func(x, y) + assert dr.all(res == ref) + + assert frozen.n_recordings == 3 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test20_literal(t): + """ + Test that drjit literals, passed to frozen functions do not cause the + function to be re-traced if they change. This is enabled by making the input + opaque. 
+ """ + + @dr.freeze + def func(x, y): + z = x + y + w = t(1) + return z, w + + # Literals + x = t(0, 1, 2) + dr.set_label(x, "x") + y = t(1) + dr.set_label(y, "y") + z, w = func(x, y) + assert dr.all(z == t(1, 2, 3)) + assert dr.all(w == t(1)) + + x = t(0, 1, 2) + y = t(1) + z, w = func(x, y) + assert dr.all(z == t(1, 2, 3)) + assert dr.all(w == t(1)) + + assert func.n_recordings == 1 + + x = t(0, 1, 2) + y = t(2) + z = func(x, y) + assert func.n_recordings == 1 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test21_pointers(t): + """ + Test that it is possible to gather from a same-sized variable. This tests + the kernel size inference algorithm as well as having two kernels in a + frozen function. + """ + UInt32 = dr.uint32_array_t(t) + + @dr.freeze + def func(x): + idx = dr.arange(UInt32, 0, dr.width(x), 3) + + return dr.gather(t, x, idx) + + y = func(t(0, 1, 2, 3, 4, 5, 6)) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test22_gather_memcpy(t): + """ + The gather operation might be elided in favor of a memcpy + if the index is a literal of size 1. + The source of the memcpy is however not known to the recording + mechansim as it might index into the source array. + """ + + def func(x, idx: int): + idx = t(idx) + return dr.gather(t, x, idx) + + frozen = dr.freeze(func) + + for i in range(3): + x = dr.arange(t, i, 3 + i) + dr.make_opaque(x) + ref = func(x, 2) + result = frozen(x, 2) + + assert dr.all(ref == result) + + assert frozen.n_recordings == 1 + + +def get_pkg(t): + with dr.detail.scoped_rtld_deepbind(): + m = pytest.importorskip("call_ext") + backend = dr.backend_v(t) + if backend == dr.JitBackend.LLVM: + return m.llvm + elif backend == dr.JitBackend.CUDA: + return m.cuda + + +@pytest.mark.parametrize("symbolic", [True]) +@pytest.test_arrays("float32, jit, -is_diff, shape=(*)") +def test23_vcall(t, symbolic): + """ + Tests a basic symbolic vcall being called inside a frozen function. + """ + pkg = get_pkg(t) + + A, B, Base, BasePtr = pkg.A, pkg.B, pkg.Base, pkg.BasePtr + Mask = dr.mask_t(t) + a, b = A(), B() + + c = BasePtr(a, a, None, a, a) + + xi = t(1, 2, 8, 3, 4) + yi = t(5, 6, 8, 7, 8) + + @dr.freeze + def func(c, xi, yi): + return c.f(xi, yi) + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, symbolic): + xo, yo = func(c, xi, yi) + + assert dr.all(xo == t(10, 12, 0, 14, 16)) + assert dr.all(yo == t(-1, -2, 0, -3, -4)) + + c = BasePtr(a, a, None, b, b) + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, symbolic): + xo, yo = func(c, xi, yi) + + assert func.n_recordings == 1 + + assert dr.all(xo == t(10, 12, 0, 21, 24)) + assert dr.all(yo == t(-1, -2, 0, 3, 4)) + + +@pytest.mark.parametrize("symbolic", [True]) +@pytest.mark.parametrize("optimize", [True, False]) +@pytest.mark.parametrize("opaque", [True, False]) +@pytest.test_arrays("float32, -is_diff, jit, shape=(*)") +def test24_vcall_optimize(t, symbolic, optimize, opaque): + """ + Test a basic vcall being called inside a frozen function, with the + "OptimizeCalls" flag either being set or not set. As well as opaque and + non-opaque inputs. 
+ """ + pkg = get_pkg(t) + + A, B, Base, BasePtr = pkg.A, pkg.B, pkg.Base, pkg.BasePtr + Mask = dr.mask_t(t) + a, b = B(), B() + + dr.set_label(A.opaque, "A.opaque") + dr.set_label(B.opaque, "B.opaque") + + a.value = t(2) + b.value = t(3) + + if opaque: + dr.make_opaque(a.value, b.value) + + c = BasePtr(a, a, None, a, a) + dr.set_label(c, "c") + + x = t(1, 2, 8, 3, 4) + dr.set_label(x, "x") + + def func(c, xi): + return c.g(xi) + + frozen = dr.freeze(func) + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, symbolic): + with dr.scoped_set_flag(dr.JitFlag.OptimizeCalls, optimize): + xo = frozen(c, x) + + assert dr.all(xo == func(c, x)) + + a.value = t(3) + b.value = t(2) + + if opaque: + dr.make_opaque(a.value, b.value) + + c = BasePtr(a, a, None, b, b) + dr.set_label(c, "c") + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, symbolic): + with dr.scoped_set_flag(dr.JitFlag.OptimizeCalls, optimize): + xo = frozen(c, x) + + assert frozen.n_recordings == 1 + assert dr.all(xo == func(c, x)) + + +@pytest.mark.parametrize("symbolic", [True]) +@pytest.mark.parametrize("optimize", [True, False]) +@pytest.mark.parametrize("opaque", [True, False]) +@pytest.test_arrays("float32, -is_diff, jit, shape=(*)") +def test25_multiple_vcalls(t, symbolic, optimize, opaque): + """ + Test calling multiple vcalls in a frozen function, where the result of the + first is used as the input to the second. + """ + pkg = get_pkg(t) + + A, B, Base, BasePtr = pkg.A, pkg.B, pkg.Base, pkg.BasePtr + Mask = dr.mask_t(t) + a, b = B(), B() + + dr.set_label(A.opaque, "A.opaque") + dr.set_label(B.opaque, "B.opaque") + + a.value = t(2) + b.value = t(3) + + if opaque: + dr.make_opaque(a.value, b.value) + + c = BasePtr(a, a, None, a, a) + dr.set_label(c, "c") + + x = t(1, 2, 8, 3, 4) + dr.set_label(x, "x") + + def func(c, xi): + x = c.h(xi) + dr.make_opaque(x) + return c.g(x) + + frozen = dr.freeze(func) + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, symbolic): + with dr.scoped_set_flag(dr.JitFlag.OptimizeCalls, optimize): + xo = frozen(c, x) + + assert dr.all(xo == func(c, x)) + + a.value = t(3) + b.value = t(2) + + if opaque: + dr.make_opaque(a.value, b.value) + + c = BasePtr(a, a, None, b, b) + dr.set_label(c, "c") + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, symbolic): + with dr.scoped_set_flag(dr.JitFlag.OptimizeCalls, optimize): + xo = frozen(c, x) + + assert frozen.n_recordings == 1 + + assert dr.all(xo == func(c, x)) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test26_freeze(t): + """ + Test freezing a simple frozen function. + """ + UInt32 = dr.uint32_array_t(t) + Float = dr.float32_array_t(t) + + @dr.freeze + def my_kernel(x): + x_int = UInt32(x) + result = x * x + result_int = UInt32(result) + + return result, x_int, result_int + + for i in range(3): + x = Float([1.0, 2.0, 3.0]) + dr.opaque(Float, i) + + y, x_int, y_int = my_kernel(x) + dr.schedule(y, x_int, y_int) + assert dr.allclose(y, dr.square(x)) + assert dr.allclose(x_int, UInt32(x)) + assert dr.allclose(y_int, UInt32(y)) + + +@pytest.mark.parametrize("freeze_first", (True, False)) +@pytest.test_arrays("float32, jit, shape=(*)") +def test27_calling_frozen_from_frozen(t, freeze_first): + """ + Test calling a frozen function from within another frozen function. + The inner frozen function should behave as a normal function. 
+ """ + mod = sys.modules[t.__module__] + Float = mod.Float32 + Array3f = mod.Array3f + n = 37 + x = dr.full(Float, 1.5, n) + dr.opaque(Float, 2) + y = dr.full(Float, 0.5, n) + dr.opaque(Float, 10) + dr.eval(x, y) + + @dr.freeze + def fun1(x): + return dr.square(x) + + @dr.freeze + def fun2(x, y): + return fun1(x) + fun1(y) + + # Calling a frozen function from a frozen function. + if freeze_first: + dr.eval(fun1(x)) + + result1 = fun2(x, y) + assert dr.allclose(result1, dr.square(x) + dr.square(y)) + + if not freeze_first: + # If the nested function hasn't been recorded yet, calling it + # while freezing the outer function shouldn't freeze it with + # those arguments. + # In other words, any freezing mechanism should be completely + # disabled while recording a frozen function. + # assert fun1.frozen.kernels is None + + # We should therefore be able to freeze `fun1` with a different + # type of argument, and both `fun1` and `fun2` should work fine. + result2 = fun1(Array3f(0.5, x, y)) + assert dr.allclose(result2, Array3f(0.5 * 0.5, dr.square(x), dr.square(y))) + + result3 = fun2(2 * x, 0.5 * y) + assert dr.allclose(result3, dr.square(2 * x) + dr.square(0.5 * y)) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test28_recorded_size(t): + """ + Tests that a frozen function, producing a variable with a constant size, + can be replayed and produces an output of the same size. + """ + mod = sys.modules[t.__module__] + UInt32 = mod.UInt32 + Float = mod.Float + + @dr.freeze + def fun(a): + x = t(dr.linspace(Float, -1, 1, 10)) + a + source = x + 2 * x + # source = get_single_entry(x + 2 * x) + index = dr.arange(UInt32, dr.width(source)) + active = index % UInt32(2) != 0 + + return dr.gather(Float, source, index, active) + + a = t(0.1) + res1 = fun(a) + res2 = fun(a) + res3 = fun(a) + + assert dr.allclose(res1, res2) + assert dr.allclose(res1, res3) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test29_with_gathers(t): + """ + Test gathering from an array at every second index in a frozen function. + """ + import numpy as np + + n = 20 + mod = sys.modules[t.__module__] + UInt32 = mod.UInt32 + + rng = np.random.default_rng(seed=1234) + shape = tuple(reversed(dr.shape(dr.zeros(t, n)))) + + def fun(x, idx): + active = idx % 2 != 0 + source = get_single_entry(x) + return dr.gather(type(source), source, idx, active=active) + + fun_frozen = dr.freeze(fun) + + # 1. Recording call + x1 = t(rng.uniform(low=-1, high=1, size=shape)) + idx1 = dr.arange(UInt32, n) + result1 = fun_frozen(x1, idx1) + assert dr.allclose(result1, fun(x1, idx1)) + + # 2. Different source as during recording + x2 = t(rng.uniform(low=-2, high=-1, size=shape)) + idx2 = idx1 + + result2 = fun_frozen(x2, idx2) + assert dr.allclose(result2, fun(x2, idx2)) + + x3 = x2 + idx3 = UInt32([i for i in reversed(range(n))]) + result3 = fun_frozen(x3, idx3) + assert dr.allclose(result3, fun(x3, idx3)) + + # 3. Same source as during recording + result4 = fun_frozen(x1, idx1) + assert dr.allclose(result4, result1) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test30_scatter_with_op(t): + """ + Tests scattering into the input of a frozen function. + """ + import numpy as np + + n = 16 + mod = sys.modules[t.__module__] + UInt32 = mod.UInt32 + + rng = np.random.default_rng(seed=1234) + + def func(x, idx): + active = idx % 2 != 0 + + result = x - 0.5 + dr.scatter(x, result, idx, active=active) + return result + + func_frozen = dr.freeze(func) + + # 1. 
Recording call + x1 = t(rng.uniform(low=-1, high=1, size=[n])) + x1_copy = t(x1) + x1_copy_copy = t(x1) + idx1 = dr.arange(UInt32, n) + + result1 = func_frozen(x1, idx1) + + # assert dr.allclose(x1, x1_copy) + assert dr.allclose(result1, func(x1_copy, idx1)) + + # 2. Different source as during recording + # TODO: problem: during trace, the actual x1 Python variable changes + # from index r2 to index r12 as a result of the `scatter`. + # But in subsequent launches, even if we successfully create a new + # output buffer equivalent to r12, it doesn't get assigned to `x2`. + x2 = t(rng.uniform(low=-2, high=-1, size=[n])) + x2_copy = t(x2) + idx2 = idx1 + + result2 = func_frozen(x2, idx2) + assert dr.allclose(result2, func(x2_copy, idx2)) + # assert dr.allclose(x2, x2_copy) + + x3 = x2 + x3_copy = t(x3) + idx3 = UInt32([i for i in reversed(range(n))]) + result3 = func_frozen(x3, idx3) + assert dr.allclose(result3, func(x3_copy, idx3)) + # assert dr.allclose(x3, x3_copy) + + # # 3. Same source as during recording + result4 = func_frozen(x1_copy_copy, idx1) + assert dr.allclose(result4, result1) + # # assert dr.allclose(x1_copy_copy, x1) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test31_with_gather_and_scatter(t): + """ + Tests a combination of scatters and gathers in a frozen function. + """ + + import numpy as np + + n = 20 + mod = sys.modules[t.__module__] + UInt32 = mod.UInt32 + + rng = np.random.default_rng(seed=1234) + shape = tuple(reversed(dr.shape(dr.zeros(t, n)))) + + def fun(x, idx): + active = idx % 2 != 0 + dest = get_single_entry(x) + + values = dr.gather(UInt32, idx, idx, active=active) + values = type(dest)(values) + dr.scatter(dest, values, idx, active=active) + return dest, values + + fun_frozen = dr.freeze(fun) + + # 1. Recording call + x1 = t(rng.uniform(low=-1, high=1, size=shape)) + x1_copy = t(x1) + x1_copy_copy = t(x1) + idx1 = dr.arange(UInt32, n) + + result1 = fun_frozen(x1, idx1) + assert dr.allclose(result1, fun(x1_copy, idx1)) + assert dr.allclose(x1, x1_copy) + + # 2. Different source as during recording + x2 = t(rng.uniform(low=-2, high=-1, size=shape)) + x2_copy = t(x2) + idx2 = idx1 + + result2 = fun_frozen(x2, idx2) + assert dr.allclose(result2, fun(x2_copy, idx2)) + assert dr.allclose(x2, x2_copy) + + x3 = x2 + x3_copy = t(x3) + idx3 = UInt32([i for i in reversed(range(n))]) + result3 = fun_frozen(x3, idx3) + assert dr.allclose(result3, fun(x3_copy, idx3)) + assert dr.allclose(x3, x3_copy) + + # 3. Same source as during recording + result4 = fun_frozen(x1_copy_copy, idx1) + assert dr.allclose(result4, result1) + assert dr.allclose(x1_copy_copy, x1) + + +@pytest.mark.parametrize("relative_size", ["<", "=", ">"]) +@pytest.test_arrays("float32, jit, shape=(*)") +def test32_gather_only_pointer_as_input(t, relative_size): + """ + Tests that it is possible to infer the launch size of kernels, if the width + of the resulting variable is a multiple/fraction of the variables from which + the result was gathered. 
+ """ + mod = sys.modules[t.__module__] + Array3f = mod.Array3f + Float = mod.Float32 + UInt32 = mod.UInt32 + + import numpy as np + + rng = np.random.default_rng(seed=1234) + + if relative_size == "<": + + def fun(v): + idx = dr.arange(UInt32, 0, dr.width(v), 3) + return Array3f( + dr.gather(Float, v, idx), + dr.gather(Float, v, idx + 1), + dr.gather(Float, v, idx + 2), + ) + + elif relative_size == "=": + + def fun(v): + idx = dr.arange(UInt32, 0, dr.width(v)) // 2 + return Array3f( + dr.gather(Float, v, idx), + dr.gather(Float, v, idx + 1), + dr.gather(Float, v, idx + 2), + ) + + elif relative_size == ">": + + def fun(v): + max_width = dr.width(v) + idx = dr.arange(UInt32, 0, 5 * max_width) + # TODO(!): what can we do against this literal being baked into the kernel? + active = (idx + 2) < max_width + return Array3f( + dr.gather(Float, v, idx, active=active), + dr.gather(Float, v, idx + 1, active=active), + dr.gather(Float, v, idx + 2, active=active), + ) + + fun_freeze = dr.freeze(fun) + + def check_results(v, result): + size = v.size + if relative_size == "<": + expected = v.T + if relative_size == "=": + idx = np.arange(0, size) // 2 + expected = v.ravel() + expected = np.stack( + [ + expected[idx], + expected[idx + 1], + expected[idx + 2], + ], + axis=0, + ) + elif relative_size == ">": + idx = np.arange(0, 5 * size) + mask = (idx + 2) < size + expected = v.ravel() + expected = np.stack( + [ + np.where(mask, expected[(idx) % size], 0), + np.where(mask, expected[(idx + 1) % size], 0), + np.where(mask, expected[(idx + 2) % size], 0), + ], + axis=0, + ) + + assert np.allclose(result.numpy(), expected) + + # Note: Does not fail for n=1 + n = 7 + + for i in range(3): + v = rng.uniform(size=[n, 3]) + result = fun(Float(v.ravel())) + check_results(v, result) + + for i in range(10): + if i <= 5: + n_lanes = n + else: + n_lanes = n + i + + v = rng.uniform(size=[n_lanes, 3]) + result = fun_freeze(Float(v.ravel())) + + expected_width = { + "<": n_lanes, + "=": n_lanes * 3, + ">": n_lanes * 3 * 5, + }[relative_size] + + # if i == 0: + # assert len(fun_freeze.frozen.kernels) + # for kernel in fun_freeze.frozen.kernels.values(): + # assert kernel.original_input_size == n * 3 + # if relative_size == "<": + # assert kernel.original_launch_size == expected_width + # assert kernel.original_launch_size_ratio == (False, 3, True) + # elif relative_size == "=": + # assert kernel.original_launch_size == expected_width + # assert kernel.original_launch_size_ratio == (False, 1, True) + # else: + # assert kernel.original_launch_size == expected_width + # assert kernel.original_launch_size_ratio == (True, 5, True) + + assert dr.width(result) == expected_width + if relative_size == ">" and n_lanes != n: + pytest.xfail( + reason="The width() of the original input is baked into the kernel to compute the `active` mask during the first launch, so results are incorrect once the width changes." + ) + + check_results(v, result) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test33_multiple_kernels(t): + def fn(x: dr.ArrayBase, y: dr.ArrayBase, flag: bool): + # TODO: test with gathers and scatters, which is a really important use-case. + # TODO: test with launches of different sizes (including the auto-sizing logic) + # TODO: test with an intermediate output of literal type + # TODO: test multiple kernels that scatter_add to a newly allocated kernel in sequence. 
+ + # First kernel uses only `x` + quantity = 0.5 if flag else -0.5 + intermediate1 = x + quantity + intermediate2 = x * quantity + dr.eval(intermediate1, intermediate2) + + # Second kernel uses `x`, `y` and one of the intermediate result + result = intermediate2 + y + + # The function returns some mix of outputs + return intermediate1, None, y, result + + n = 15 + x = dr.full(t, 1.5, n) + dr.opaque(t, 0.2) + y = dr.full(t, 0.5, n) + dr.opaque(t, 0.1) + dr.eval(x, y) + + ref_results = fn(x, y, flag=True) + dr.eval(ref_results) + + fn_frozen = dr.freeze(fn) + for _ in range(2): + results = fn_frozen(x, y, flag=True) + assert dr.allclose(results[0], ref_results[0]) + assert results[1] is None + assert dr.allclose(results[2], y) + assert dr.allclose(results[3], ref_results[3]) + + # TODO: + # We don't yet make a difference between check and no-check + + # for i in range(4): + # new_y = y + float(i) + # # Note: we did not enabled `check` mode, so changing this Python + # # value will not throw an exception. The new value has no influence + # # on the result even though without freezing, it would. + # # TODO: support "signature" detection and create separate frozen + # # function instances. + # new_flag = (i % 2) == 0 + # results = fn_frozen(x, new_y, flag=new_flag) + # assert dr.allclose(results[0], ref_results[0]) + # assert results[1] is None + # assert dr.allclose(results[2], new_y) + # assert dr.allclose(results[3], x * 0.5 + new_y) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test34_global_flag(t): + Float = t + + @dr.freeze + def my_fn(a, b, c=0.5): + return a + b + c + + # Recording + one = Float([1.0] * 9) + result1 = my_fn(one, one, c=0.1) + assert dr.allclose(result1, 2.1) + + # Can change the type of an input + result2 = my_fn(one, one, c=Float(0.6)) + assert dr.allclose(result2, 2.6) + + assert my_fn.n_recordings == 2 + + # Disable frozen kernels globally, now the freezing + # logic should be completely bypassed + with dr.scoped_set_flag(dr.JitFlag.KernelFreezing, False): + result3 = my_fn(one, one, c=0.9) + assert dr.allclose(result3, 2.9) + + +# @pytest.mark.parametrize("struct_style", ["drjit", "dataclass"]) +@pytest.mark.parametrize("struct_style", ["drjit", "dataclass"]) +# @pytest.test_arrays("float32, llvm, jit, -is_diff, shape=(*)") +@pytest.test_arrays("float32, jit, shape=(*)") +def test35_return_types(t, struct_style): + # WARN: only working on CUDA! + mod = sys.modules[t.__module__] + Float = t + Array3f = mod.Array3f + UInt32 = mod.UInt32 + + import numpy as np + + if struct_style == "drjit": + + class ToyDataclass: + DRJIT_STRUCT: dict = {"a": Float, "b": Float} + a: Float + b: Float + + def __init__(self, a=None, b=None): + self.a = a + self.b = b + + else: + assert struct_style == "dataclass" + + @dataclass(frozen=True) + class ToyDataclass: + a: Float + b: Float + + # 1. 
Many different types + @dr.freeze + def toy1(x: Float) -> Float: + y = x**2 + dr.sin(x) + z = x**2 + dr.cos(x) + return (x, y, z, ToyDataclass(a=x, b=y), {"x": x, "yi": UInt32(y)}, [[[[x]]]]) + + for i in range(2): + input = Float(np.full(17, i)) + # input = dr.full(Float, i, 17) + result = toy1(input) + assert isinstance(result[0], Float) + assert isinstance(result[1], Float) + assert isinstance(result[2], Float) + assert isinstance(result[3], ToyDataclass) + assert isinstance(result[4], dict) + assert result[4].keys() == set(("x", "yi")) + assert isinstance(result[4]["x"], Float) + assert isinstance(result[4]["yi"], UInt32) + assert isinstance(result[5], list) + assert isinstance(result[5][0], list) + assert isinstance(result[5][0][0], list) + assert isinstance(result[5][0][0][0], list) + + # 2. Many different types + @dr.freeze + def toy2(x: Float, target: Float) -> Float: + dr.scatter(target, 0.5 + x, dr.arange(UInt32, dr.width(x))) + return None + + for i in range(3): + input = Float([i] * 17) + target = dr.opaque(Float, 0, dr.width(input)) + # target = dr.full(Float, 0, dr.width(input)) + # target = dr.empty(Float, dr.width(input)) + + result = toy2(input, target) + assert dr.allclose(target, 0.5 + input) + assert result is None + + # 3. DRJIT_STRUCT as input and returning nested dictionaries + @dr.freeze + def toy3(x: Float, y: ToyDataclass) -> Float: + x_d = dr.detach(x, preserve_type=False) + return { + "a": x, + "b": (x, UInt32(2 * y.a + y.b)), + "c": None, + "d": { + "d1": x + x, + "d2": Array3f(x_d, -x_d, 2 * x_d), + "d3": None, + "d4": {}, + "d5": tuple(), + "d6": list(), + "d7": ToyDataclass(a=x, b=2 * x), + }, + "e": [x, {"e1": None}], + } + + for i in range(3): + input = Float([i] * 5) + input2 = ToyDataclass(a=input, b=Float(4.0)) + result = toy3(input, input2) + assert isinstance(result, dict) + assert isinstance(result["a"], Float) + assert isinstance(result["b"], tuple) + assert isinstance(result["b"][0], Float) + assert isinstance(result["b"][1], UInt32) + assert result["c"] is None + assert isinstance(result["d"], dict) + assert isinstance(result["d"]["d1"], Float) + assert isinstance(result["d"]["d2"], Array3f) + assert result["d"]["d3"] is None + assert isinstance(result["d"]["d4"], dict) and len(result["d"]["d4"]) == 0 + assert isinstance(result["d"]["d5"], tuple) and len(result["d"]["d5"]) == 0 + assert isinstance(result["d"]["d6"], list) and len(result["d"]["d6"]) == 0 + assert isinstance(result["d"]["d7"], ToyDataclass) + assert dr.allclose(result["d"]["d7"].a, input) + assert dr.allclose(result["d"]["d7"].b, 2 * input) + assert isinstance(result["e"], list) + assert isinstance(result["e"][0], Float) + assert isinstance(result["e"][1], dict) + assert result["e"][1]["e1"] is None + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test36_drjit_struct_and_matrix(t): + package = sys.modules[t.__module__] + Float = package.Float + Array4f = package.Array4f + Matrix4f = package.Matrix4f + + class MyTransform4f: + DRJIT_STRUCT = { + "matrix": Matrix4f, + "inverse": Matrix4f, + } + + def __init__(self, matrix: Matrix4f = None, inverse: Matrix4f = None): + self.matrix = matrix + self.inverse = inverse + + @dataclass(frozen=False) + class Camera: + to_world: MyTransform4f + + @dataclass(frozen=False) + class Batch: + camera: Camera + value: float = 0.5 + offset: float = 0.5 + + @dataclass(frozen=False) + class Result: + value: Float + constant: int = 5 + + def fun(batch: Batch, x: Array4f): + res1 = batch.camera.to_world.matrix @ x + res2 = 
batch.camera.to_world.matrix @ x + batch.offset + res3 = batch.value + x + res4 = Result(value=batch.value) + return res1, res2, res3, res4 + + fun_frozen = dr.freeze(fun) + + n = 7 + for i in range(4): + x = Array4f( + *(dr.linspace(Float, 0, 1, n) + dr.opaque(Float, i) + k for k in range(4)) + ) + mat = Matrix4f( + *( + dr.linspace(Float, 0, 1, n) + dr.opaque(Float, i) + ii + jj + for jj in range(4) + for ii in range(4) + ) + ) + trafo = MyTransform4f() + trafo.matrix = mat + trafo.inverse = dr.rcp(mat) + + batch = Batch( + camera=Camera(to_world=trafo), + value=dr.linspace(Float, -1, 0, n) - dr.opaque(Float, i), + ) + # dr.eval(x, trafo, batch.value) + + results = fun_frozen(batch, x) + expected = fun(batch, x) + + assert len(results) == len(expected) + for result_i, (value, expected) in enumerate(zip(results, expected)): + + assert type(value) == type(expected) + if isinstance(value, Result): + value = value.value + expected = expected.value + assert dr.allclose(value, expected), str(result_i) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test37_with_dataclass_in_out(t): + mod = sys.modules[t.__module__] + Int32 = mod.Int32 + UInt32 = mod.UInt32 + Bool = mod.Bool + + @dataclass(frozen=False) + class MyRecord: + step_in_segment: Int32 + total_steps: UInt32 + short_segment: Bool + + def acc_fn(record: MyRecord): + record.step_in_segment += Int32(2) + return Int32(record.total_steps + record.step_in_segment) + + # Initialize MyRecord + n_rays = 100 + record = MyRecord( + step_in_segment=UInt32([1] * n_rays), + total_steps=UInt32([0] * n_rays), + short_segment=dr.zeros(Bool, n_rays), + ) + + # Create frozen kernel that contains another function + frozen_acc_fn = dr.freeze(acc_fn) + + accumulation = dr.zeros(UInt32, n_rays) + n_iter = 12 + for _ in range(n_iter): + accumulation += frozen_acc_fn(record) + + expected = 0 + for i in range(n_iter): + expected += 0 + 2 * (i + 1) + 1 + assert dr.all(accumulation == expected) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test38_allocated_scratch_buffer(t): + """ + Frozen functions may want to allocate some scratch space, scatter to it + in a first kernel, and read / use the values later on. As long as the + size of the scratch space can be guessed (e.g. a multiple of the launch width, + or matching the width of an existing input), we should be able to support it. + + On the other hand, the "scattering to an unknown buffer" pattern may actually + be scattering to an actual pre-existing buffer, which the user simply forgot + to include in the `state` lambda. In order to catch that case, we at least + check that the "scratch buffer" was read from in one of the kernels. + Otherwise, we assume it was meant as a side-effect into a pre-existing buffer. + """ + mod = sys.modules[t.__module__] + # dr.set_flag(dr.JitFlag.KernelFreezing, False) + UInt32 = mod.UInt32 + + # Note: we are going through an object / method, otherwise the closure + # checker would already catch the `forgotten_target_buffer` usage. 
+ class Model: + DRJIT_STRUCT = { + "some_state": UInt32, + # "forgotten_target_buffer": UInt32, + } + + def __init__(self): + self.some_state = UInt32([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + self.forgotten_target_buffer = self.some_state + 1 + dr.eval(self.some_state, self.forgotten_target_buffer) + + @dr.freeze + def fn1(self, x): + # Note: assuming here that the width of `forgotten_target_buffer` doesn't change + index = dr.arange(UInt32, dr.width(x)) % dr.width( + self.forgotten_target_buffer + ) + dr.scatter(self.forgotten_target_buffer, x, index) + + return 2 * x + + @dr.freeze + def fn2(self, x): + # Scratch buffer with width equal to a state variable + scratch = dr.zeros(UInt32, dr.width(self.some_state)) + # Kernel 1: write to `scratch` + index = dr.arange(UInt32, dr.width(x)) % dr.width(self.some_state) + dr.scatter(scratch, x, index) + # Kernel 2: use values from `scratch` directly + result = dr.square(scratch) + # We don't actually return `scratch`, its lifetime is limited to the frozen function. + return result + + @dr.freeze + def fn3(self, x): + # Scratch buffer with width equal to a state variable + scratch = dr.zeros(UInt32, dr.width(self.some_state)) + # Kernel 1: write to `scratch` + index = dr.arange(UInt32, dr.width(x)) % dr.width(self.some_state) + dr.scatter(scratch, x, index) + # Kernel 2: use values from `scratch` via a gather + result = x + dr.gather(UInt32, scratch, index) + # We don't actually return `scratch`, its lifetime is limited to the frozen function. + return result + + model = Model() + + # Suspicious usage, should not allow it to avoid silent surprising behavior + for i in range(4): + x = UInt32(list(range(i + 2))) + assert dr.width(x) < dr.width(model.forgotten_target_buffer) + + if dr.flag(dr.JitFlag.KernelFreezing): + with pytest.raises(RuntimeError): + result = model.fn1(x) + break + + else: + result = model.fn1(x) + assert dr.allclose(result, 2 * x) + + expected = UInt32(model.some_state + 1) + dr.scatter(expected, x, dr.arange(UInt32, dr.width(x))) + assert dr.allclose(model.forgotten_target_buffer, expected) + + # Expected usage, we should allocate the buffer and allow the launch + for i in range(4): + x = UInt32(list(range(i + 2))) # i+1 + assert dr.width(x) < dr.width(model.some_state) + result = model.fn2(x) + expected = dr.zeros(UInt32, dr.width(model.some_state)) + dr.scatter(expected, x, dr.arange(UInt32, dr.width(x))) + assert dr.allclose(result, dr.square(expected)) + + # Expected usage, we should allocate the buffer and allow the launch + for i in range(4): + x = UInt32(list(range(i + 2))) # i+1 + assert dr.width(x) < dr.width(model.some_state) + result = model.fn3(x) + assert dr.allclose(result, 2 * x) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test39_simple_reductions(t): + import numpy as np + + mod = sys.modules[t.__module__] + Float = mod.Float32 + n = 37 + + @dr.freeze + def simple_sum(x): + return dr.sum(x) + + @dr.freeze + def simple_product(x): + return dr.prod(x) + + @dr.freeze + def simple_min(x): + return dr.min(x) + + @dr.freeze + def simple_max(x): + return dr.max(x) + + @dr.freeze + def sum_not_returned_wide(x): + return dr.sum(x) + x + + @dr.freeze + def sum_not_returned_single(x): + return dr.sum(x) + 4 + + def check_expected(fn, expected): + result = fn(x) + + assert dr.width(result) == dr.width(expected) + assert isinstance(result, Float) + assert dr.allclose(result, expected) + + for i in range(3): + x = dr.linspace(Float, 0, 1, n) + dr.opaque(Float, i) + + x_np = x.numpy() + check_expected(simple_sum, 
np.sum(x_np).item()) + check_expected(simple_product, np.prod(x_np).item()) + check_expected(simple_min, np.min(x_np).item()) + check_expected(simple_max, np.max(x_np).item()) + + check_expected(sum_not_returned_wide, np.sum(x_np).item() + x) + check_expected(sum_not_returned_single, np.sum(x_np).item() + 4) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test40_prefix_reductions(t): + import numpy as np + + mod = sys.modules[t.__module__] + Float = mod.Float32 + n = 37 + + @dr.freeze + def prefix_sum(x): + return dr.prefix_reduce(dr.ReduceOp.Add, x, exclusive=False) + + def check_expected(fn, expected): + result = fn(x) + + assert dr.width(result) == dr.width(expected) + assert isinstance(result, Float) + assert dr.allclose(result, expected) + + for i in range(3): + x = dr.linspace(Float, 0, 1, n) + dr.opaque(Float, i) + + x_np = x.numpy() + check_expected(prefix_sum, Float(np.cumsum(x_np))) + + +@pytest.test_arrays("float32, jit, is_diff, shape=(*)") +def test41_reductions_with_ad(t): + Float = t + n = 37 + + @dr.freeze + def sum_with_ad(x, width_opaque): + intermediate = 2 * x + 1 + dr.enable_grad(intermediate) + + result = dr.square(intermediate) + + # Unfortunately, as long as we don't support creating opaque values + # within a frozen kernel, we can't use `dr.mean()` directly. + loss = dr.sum(result) / width_opaque + # loss = dr.mean(result) + dr.backward(loss) + + return result, intermediate + + @dr.freeze + def product_with_ad(x): + dr.enable_grad(x) + loss = dr.prod(x) + dr.backward_from(loss) + + for i in range(3): + x = dr.linspace(Float, 0, 1, n + i) + dr.opaque(Float, i) + result, intermediate = sum_with_ad(x, dr.opaque(Float, dr.width(x))) + assert dr.width(result) == n + i + + assert dr.grad_enabled(result) + assert dr.grad_enabled(intermediate) + assert not dr.grad_enabled(x) + intermediate_expected = 2 * x + 1 + assert dr.allclose(intermediate, intermediate_expected) + assert dr.allclose(result, dr.square(intermediate_expected)) + assert sum_with_ad.n_recordings == 1 + assert dr.allclose(dr.grad(result), 0) + assert dr.allclose( + dr.grad(intermediate), 2 * intermediate_expected / dr.width(x) + ) + + for i in range(3): + x = dr.linspace(Float, 0.1, 1, n + i) + dr.opaque(Float, i) + result = product_with_ad(x) + + assert result is None + assert dr.grad_enabled(x) + with dr.suspend_grad(): + expected_grad = dr.prod(x) / x + assert dr.allclose(dr.grad(x), expected_grad) + + +# @pytest.test_arrays("float32, jit, shape=(*)") +# def test35_mean(t): +# def func(x): +# return dr.mean(x) +# +# frozen_func = dr.freeze(func) +# +# n = 10 +# +# for i in range(3): +# x = dr.linspace(t, 0, 1, n + i) + dr.opaque(t, i) +# +# result = frozen_func(x) +# assert dr.allclose(result, func(x)) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test42_size_aliasing(t): + def func(x, y): + return x + 1, y + 2 + + frozen_func = dr.freeze(func) + + n = 3 + + for i in range(3): + x = dr.linspace(t, 0, 1, n) + dr.opaque(t, i) + y = dr.linspace(t, 0, 1, n + i) + dr.opaque(t, i) + + result = frozen_func(x, y) + + assert dr.allclose(result, func(x, y)) + + """ + We should have two recordings: one in which the expressions x+1 and y+2 are compiled into the same kernel, + because x and y have the same size, and one in which they are compiled separately because their sizes differ.
+ """ + assert frozen_func.n_recordings == 2 + + +@pytest.test_arrays("float32, jit, -is_diff, shape=(*)") +def test43_pointer_aliasing(t): + """ + Dr.Jit employs a memory cache, which means that two variables + get allocated the same memory region, if one is destroyed + before the other is created. + Since we track variables using their pointers in the `RecordThreadState`, + we have to update the `ptr_to_slot` mapping for new variables. + """ + + n = 4 + + def func(x): + y = x + 1 + dr.make_opaque(y) + for i in range(3): + y = y + 1 + dr.make_opaque(y) + return y + + for i in range(10): + frozen_func = dr.freeze(func) + + x = dr.linspace(t, 0, 1, n) + dr.opaque(t, i) + assert dr.allclose(frozen_func(x), func(x)) + + x = dr.linspace(t, 0, 1, n) + dr.opaque(t, i) + assert dr.allclose(frozen_func(x), func(x)) + + +@pytest.test_arrays("float32, jit, is_diff, shape=(*)") +def test44_simple_ad_fully_inside(t): + mod = sys.modules[t.__module__] + Float = mod.Float + + def my_kernel(x): + dr.enable_grad(x) + + result = x * x + dr.backward(result) + + return result + + for start_enabled in (True, False): + # Re-freeze + my_kernel_frozen = dr.freeze(my_kernel) + + for i in range(3): + x = Float([1.0, 2.0, 3.0]) + dr.opaque(Float, i) + if start_enabled: + dr.enable_grad(x) + + y = my_kernel_frozen(x) + grad_x = dr.grad(x) + grad_y = dr.grad(y) + dr.schedule(y, grad_x, grad_y) + assert dr.allclose(y, dr.square(x)) + assert dr.allclose(grad_y, 0) + assert dr.allclose(grad_x, 2 * x) + + # Status of grad_enabled should be restored (side-effect of the function), + # even if it wasn't enabled at first + assert dr.grad_enabled(x) + + +@pytest.mark.parametrize("set_some_literal_grad", (False,)) +@pytest.mark.parametrize("inputs_end_enabled", (True, False)) +@pytest.mark.parametrize("inputs_start_enabled", (True,)) +@pytest.mark.parametrize("params_end_enabled", (False, False)) +@pytest.mark.parametrize("params_start_enabled", (True,)) +@pytest.mark.parametrize("freeze", (True,)) +@pytest.test_arrays("float32, jit, is_diff, shape=(*)") +def test45_suspend_resume( + t, + params_start_enabled, + params_end_enabled, + inputs_start_enabled, + inputs_end_enabled, + set_some_literal_grad, + freeze, +): + + mod = sys.modules[t.__module__] + Float = mod.Float32 + UInt32 = mod.UInt32 + log_level = dr.log_level() + + # TODO: remove this + # dr.set_flag(dr.JitFlag.KernelFreezing, False) + + class MyModel: + DRJIT_STRUCT = {"params": Float} + + def __init__(self, params): + self.params = params + self.frozen_eval = dr.freeze(type(self).eval) if freeze else type(self).eval + + def eval( + self, + x: Float, + params_end_enabled: bool, + inputs_end_enabled: bool, + set_some_literal_grad: bool, + ): + idx = dr.arange(UInt32, dr.width(x)) % dr.width(self.params) + latents = dr.gather(Float, self.params, idx) + result = x * latents + + with dr.resume_grad(): + dr.set_grad_enabled(self.params, params_end_enabled) + dr.set_grad_enabled(x, inputs_end_enabled) + if set_some_literal_grad: + # If grads are not enabled, this will get ignored, which is fine + dr.set_grad(x, Float(6.66)) + + return result + + model = MyModel(params=Float([1, 2, 3, 4, 5])) + + for i in range(3): + # Inputs of different widths + x = Float([0.1, 0.2, 0.3, 0.4, 0.5, 0.6] * (i + 1)) + dr.opaque(Float, i) + + dr.set_grad_enabled(model.params, params_start_enabled) + dr.set_grad_enabled(x, inputs_start_enabled) + + dr.eval(x, dr.grad(x)) + + with dr.suspend_grad(): + result = model.frozen_eval( + model, x, params_end_enabled, inputs_end_enabled, 
set_some_literal_grad + ) + + # dr.eval(result, model.params, dr.grad(model.params)) + assert not dr.grad_enabled(result) + assert dr.grad_enabled(model.params) == params_end_enabled + assert dr.grad_enabled(x) == inputs_end_enabled + + # The frozen function should restore the right width, even for a zero-valued literal. + # The default gradients are a zero-valued literal array + # with a width equal to the array's width + grads = dr.grad(model.params) + assert dr.width(grads) == dr.width(model.params) + assert dr.all(grads == 0) + + grads = dr.grad(x) + assert dr.width(grads) == dr.width(x) + if inputs_end_enabled and set_some_literal_grad: + assert dr.all(grads == 6.66) + else: + assert dr.all(grads == 0) + + assert model.frozen_eval.n_recordings == 1 + + +@pytest.test_arrays("float32, jit, is_diff, shape=(*)") +@pytest.mark.parametrize("freeze", (True,)) +@pytest.mark.parametrize("change_params_width", (False,)) +def test46_with_grad_scatter(t, freeze: bool, change_params_width): + mod = sys.modules[t.__module__] + Float = mod.Float32 + UInt32 = mod.UInt32 + log_level = dr.log_level() + + class Model: + DRJIT_STRUCT = {"params": Float} + + def __init__(self, n): + self.params = Float(list(range(1, n + 1))) + assert dr.width(self.params) == n + dr.enable_grad(self.params) + + def __call__(self): + # Cheeky workaround for the frozen kernel signature checking + pass + + def my_kernel(model, x, opaque_params_width): + idx = dr.arange(UInt32, dr.width(x)) % opaque_params_width + + with dr.resume_grad(): + latents = dr.gather(Float, model.params, idx) + contrib = x * latents + dr.backward_from(contrib) + + return dr.detach(contrib) + + model = Model(5) + my_kernel_frozen = dr.freeze(my_kernel) if freeze else my_kernel + + for i in range(6): + # Different width at each iteration + x = Float([1.0, 2.0, 3.0] * (i + 1)) + dr.opaque(Float, i) + + # The frozen kernel should also support the params (and therefore its gradient buffer) + # changing width without issues. + if change_params_width and (i == 3): + model = Model(10) + # Reset gradients + dr.set_grad(model.params, 0) + assert dr.grad_enabled(model.params) + + with dr.suspend_grad(): + y = my_kernel_frozen(model, x, dr.opaque(UInt32, dr.width(model.params))) + assert not dr.grad_enabled(x) + assert not dr.grad_enabled(y) + assert dr.grad_enabled(model.params) + + grad_x = dr.grad(x) + grad_y = dr.grad(y) + grad_p = dr.grad(model.params) + # assert dr.allclose(y, dr.sqr(x)) + + # Expected grads + assert dr.allclose(grad_y, 0) + assert dr.allclose(grad_x, 0) + grad_p_expected = dr.zeros(Float, dr.width(model.params)) + idx = dr.arange(UInt32, dr.width(x)) % dr.width(model.params) + dr.scatter_reduce(dr.ReduceOp.Add, grad_p_expected, x, idx) + assert dr.allclose(grad_p, grad_p_expected) + + +@pytest.test_arrays("float32, jit, is_diff, shape=(*)") +def test47_tutorial_example(t): + mod = sys.modules[t.__module__] + Float = mod.Float32 + UInt32 = mod.UInt32 + + @dr.freeze + def frozen_eval(inputs, idx, params, target_value, grad_factor): + intermediate = dr.gather(Float, params, idx) + result = 0.5 * dr.square(intermediate) * inputs + + # Since reductions are not supported yet, we cannot compute a single + # loss value here. It's not really a problem though, since DrJit can + # backpropagate starting from arrays of any widths. 
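+
+        # Added note: `grad_factor` is computed from an opaque width so that it is
+        # not baked into the kernel as a literal; it effectively combines the
+        # optimizer's loss scale with the 1/N normalization that a scalar
+        # dr.mean()-style loss would otherwise provide.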
+ loss_per_entry = dr.square(result - target_value) * grad_factor + + # The gradients resulting from backpropagation will be directly accumulated + # (via dr.scatter_add()) into the gradient buffer of `params` (= `dr.grad(params)`). + dr.backward_from(loss_per_entry) + + # It's fine to return the primal values of `result`, but keep in mind that they will + # not be differentiable w.r.t. `params`. + return dr.detach(result) + + params = Float([1, 2, 3, 4, 5]) + + for _ in range(3): + dr.disable_grad(params) + dr.enable_grad(params) + assert dr.all(dr.grad(params) == 0) + + inputs = Float([0.1, 0.2, 0.3]) + idx = UInt32([1, 2, 3]) + # Represents the optimizer's loss scale + grad_factor = 4096 / dr.opaque(Float, dr.width(inputs)) + + result = frozen_eval( + inputs, idx, params, target_value=0.5, grad_factor=grad_factor + ) + assert not dr.grad_enabled(result) + # Gradients were correctly accumulated to `params`'s gradients. + assert not dr.all(dr.grad(params) == 0) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test48_compress(t): + + mod = sys.modules[t.__module__] + Float = mod.Float32 + UInt32 = mod.UInt32 + + pkg = get_pkg(t) + Sampler = pkg.Sampler + + def func(sampler: Sampler) -> UInt32: + indices = dr.compress(sampler.next() < 0.5) + return indices + + frozen = dr.freeze(func) + + sampler_func = Sampler(10) + sampler_frozen = Sampler(10) + for i in range(3): + dr.all(frozen(sampler_frozen) == func(sampler_func)) + + sampler_func = Sampler(11) + sampler_frozen = Sampler(11) + for i in range(3): + dr.all(frozen(sampler_frozen) == func(sampler_func)) + + assert frozen.n_recordings == 1 + + +@pytest.test_arrays("uint32, llvm, -is_diff, jit, shape=(*)") +def test49_scatter_reduce_expanded(t): + + def func(target: t, src: t): + dr.scatter_reduce(dr.ReduceOp.Add, target, src, dr.arange(t, dr.width(src)) % 2) + + frozen = dr.freeze(func) + + for i in range(4): + src = dr.full(t, 1, 10 + i) + dr.make_opaque(src) + + result = t([0] * (i + 2)) + dr.make_opaque(result) + frozen(result, src) + + reference = t([0] * (i + 2)) + dr.make_opaque(reference) + func(reference, src) + + assert dr.all(result == reference) + + assert frozen.n_cached_recordings == 1 + assert frozen.n_recordings == 4 + + +@pytest.test_arrays("uint32, llvm, -is_diff, jit, shape=(*)") +def test50_scatter_reduce_expanded_identity(t): + + def func(src: t): + target = dr.zeros(t, 5) + dr.scatter_reduce(dr.ReduceOp.Add, target, src, dr.arange(t, dr.width(src)) % 2) + return target + + frozen = dr.freeze(func) + + for i in range(4): + src = dr.full(t, 1, 10 + i) + dr.make_opaque(src) + + result = frozen(src) + + reference = func(src) + + assert dr.all(result == reference) + + assert frozen.n_recordings == 1 + + +@pytest.test_arrays("uint32, llvm, -is_diff, jit, shape=(*)") +def test51_scatter_reduce_expanded_no_memset(t): + + def func(src: t): + target = dr.full(t, 5) + dr.scatter_reduce(dr.ReduceOp.Add, target, src, dr.arange(t, dr.width(src)) % 2) + return target + + frozen = dr.freeze(func) + + for i in range(4): + src = dr.full(t, 1, 10 + i) + dr.make_opaque(src) + + result = frozen(src) + + reference = func(src) + + assert dr.all(result == reference) + + assert frozen.n_recordings == 1 + assert frozen.n_cached_recordings == 1 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test52_python_inputs(t): + + def func(x: t, neg: bool): + if neg: + return -x + 1 + else: + return x + 1 + + frozen = dr.freeze(func) + + for i in range(3): + for neg in [False, True]: + x = t(1, 2, 3) + dr.opaque(t, i) + + res = frozen(x, 
neg) + ref = func(x, neg) + assert dr.all(res == ref) + + assert frozen.n_recordings == 2 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test53_scatter_inc(t): + + mod = sys.modules[t.__module__] + Float = mod.Float32 + UInt32 = mod.UInt32 + + n = 37 + + def acc_with_scatter_inc(x, counter, out_buffer, max_points): + active = x > 0.5 + out_idx = dr.scatter_inc(counter, UInt32(0), active=active) + active &= out_idx < max_points + + # TODO: also test within a loop + dr.scatter(out_buffer, x, out_idx, active=active) + + def test(i, func): + x = dr.linspace(Float, 0.1, 1, n + i) + dr.opaque(Float, i) / 100 + counter = UInt32(dr.opaque(UInt32, 0)) + out_buffer = dr.zeros(Float, 10) + max_points = UInt32(dr.opaque(UInt32, dr.width(out_buffer))) + + dr.set_label(x, "x") + dr.set_label(counter, "counter") + dr.set_label(out_buffer, "out_buffer") + dr.set_label(max_points, "max_points") + dr.eval(x, counter, out_buffer, max_points) + + func(x, counter, out_buffer, max_points) + + return out_buffer, counter + + def func(i): + return test(i, acc_with_scatter_inc) + + acc_with_scatter_inc = dr.freeze(acc_with_scatter_inc) + + def frozen(i): + return test(i, acc_with_scatter_inc) + + for i in range(3): + + res, _ = frozen(i) + ref, _ = func(i) + + # assert dr.all(res == ref) + + # Should have filled all of the entries + assert dr.all(res > 0.5) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test54_read_while_frozen(t): + # dr.set_flag(dr.JitFlag.KernelFreezing, True) + assert dr.flag(dr.JitFlag.KernelFreezing) + + def func(x): + return x[1] + + frozen = dr.freeze(func) + + x = t(1, 2, 3) + with pytest.raises(RuntimeError): + frozen(x) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +def test55_var_upload(t): + def func(x): + + arrays = [] + + for i in range(3): + y = dr.arange(t, 3) + dr.make_opaque(y) + arrays.append(y) + + del arrays + del y + + return x / t(10, 10, 10) + + frozen = dr.freeze(func) + + for i in range(3): + x = dr.arange(t, 3) + dr.make_opaque(x) + + # with pytest.raises(RuntimeError, match = "created while recording"): + with pytest.raises(RuntimeError): + z = frozen(x) + + # assert dr.allclose(z, func(x)) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +def test56_grad_isolate(t): + dr.set_flag(dr.JitFlag.ReuseIndices, False) + + def f(x): + return x * 2 + + def g(y): + return y * 3 + + def func(y): + z = g(y) + dr.backward(z) + + frozen = dr.freeze(func) + + for i in range(3): + + x = dr.arange(t, 3) + dr.make_opaque(x) + dr.enable_grad(x) + + y = f(x) + with dr.isolate_grad(): + func(y) + + ref = dr.grad(x) + + x = dr.arange(t, 3) + dr.make_opaque(x) + dr.enable_grad(x) + + y = f(x) + dr.make_opaque(y) + frozen(y) + + res = dr.grad(x) + + assert dr.allclose(ref, res) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +def test57_isolate_grad_fwd(t): + + def f(x): + return x * x + + def g(y): + return y * 2 + + def func(x): + with dr.isolate_grad(): + y = f(x) + dr.forward(x) + return y + + def frozen(x): + y = f(x) + dr.forward(x) + return y + + frozen = dr.freeze(frozen) + + for i in range(3): + x = t(i) + dr.make_opaque(x) + dr.enable_grad(x) + + y = func(x) + # z = g(y) + + ref = dr.grad(y) + + x = t(i) + dr.make_opaque(x) + dr.enable_grad(x) + + y = frozen(x) + # z = g(y) + + res = dr.grad(y) + + assert dr.allclose(ref, res) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +def test58_grad_postponed_part(t): + dr.set_flag(dr.JitFlag.ReuseIndices, False) + + def f(x): + return x * x * 2 + + def g(y): + return y * y * 3 + + 
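+    # Only z1 is backpropagated inside `func`; z2 is computed but contributes no
+    # gradients at this point. This is the "postponed" behavior that the test
+    # compares against the dr.isolate_grad() reference further below.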
def func(y1, y2): + z1 = g(y1) + z2 = g(y2) + dr.backward(z1) + + frozen = dr.freeze(func) + + def run(i, name, func): + x1 = dr.arange(t, 3) + i + dr.make_opaque(x1) + dr.enable_grad(x1) + y1 = f(x1) + + x2 = dr.arange(t, 3) + i + dr.make_opaque(x2) + dr.enable_grad(x2) + dr.set_grad(x2, 2) + y2 = f(x2) + + func(y1, y2) + + dx1 = dr.grad(x1) + dx2_1 = dr.grad(x2) + + dr.set_grad(x2, 1) + dr.backward(x2) + dx1_2 = dr.grad(x2) + + return [dx1, dx1_2, dx2_1] + + for i in range(3): + + def isolated(y1, y2): + with dr.isolate_grad(): + func(y1, y2) + + ref = run(i, "reference", isolated) + res = run(i, "frozen", frozen) + + for ref, res in zip(ref, res): + assert dr.allclose(ref, res) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +def test59_nested(t): + + pkg = get_pkg(t) + mod = sys.modules[t.__module__] + + A, B, Base, BasePtr = pkg.A, pkg.B, pkg.Base, pkg.BasePtr + a, b = A(), B() + a.value = dr.ones(t, 16) + dr.enable_grad(a.value) + + U = mod.UInt32 + xi = t(1, 2, 8, 3, 4) + yi = dr.reinterpret_array(U, BasePtr(a, a, a, a, a)) + + def nested(self, xi, yi): + return self.nested(xi, yi) + + def func(c, xi, yi): + return dr.dispatch(c, nested, xi, yi) + + frozen = dr.freeze(func) + + for i in range(3): + c = BasePtr(a, a, a, b, b) + xi = t(1, 2, 8, 3, 4) + yi = dr.reinterpret_array(U, BasePtr(a, a, a, a, a)) + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, True): + xref = func(c, xi, yi) + + assert dr.all(xref == xi + 1) + + c = BasePtr(a, a, a, b, b) + xi = t(1, 2, 8, 3, 4) + yi = dr.reinterpret_array(U, BasePtr(a, a, a, a, a)) + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, True): + xfrozen = frozen(c, xi, yi) + + assert dr.all(xfrozen == xref) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +def test60_call_raise(t): + + mod = sys.modules[t.__module__] + pkg = get_pkg(t) + + UInt = mod.UInt + + A, B, Base, BasePtr = pkg.A, pkg.B, pkg.Base, pkg.BasePtr + a, b = A(), B() + + def f(x: t): + raise RuntimeError("test") + + def g(self, x: t): + if isinstance(self, B): + raise RuntimeError + return x + 1 + + c = BasePtr(a, a, a, b, b) + + with pytest.raises(RuntimeError): + dr.dispatch(c, g, t(1, 1, 2, 2, 2)) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +def test61_reduce_dot(t): + def func(x, y): + return dr.dot(x, y) + + frozen = dr.freeze(func) + + for i in range(3): + x = dr.arange(t, 10 + i) + y = dr.arange(t, 10 + i) + + result = frozen(x, y) + reference = func(x, y) + + assert dr.allclose(result, reference) + + assert frozen.n_recordings == 1 + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test62_clear(t): + @dr.freeze + def func(x): + return x + 1 + + x = dr.arange(t, 10) + y = func(x) + assert func.n_recordings == 1 + + func.clear() + assert func.n_recordings == 0 + + x = dr.arange(t, 10) + y = func(x) + assert func.n_recordings == 1 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +def test63_method_decorator(t): + mod = sys.modules[t.__module__] + + class Custom: + DRJIT_STRUCT = {"state": t} + + def __init__(self) -> None: + self.state = t([1, 2, 3]) + + @dr.freeze + def frozen(self, x): + return x + self.state + + def func(self, x): + return x + self.state + + c = Custom() + for i in range(3): + x = dr.arange(t, 3) + i + dr.make_opaque(x) + res = c.frozen(x) + ref = c.func(x) + + assert dr.all(res == ref) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test64_tensor(t): + """ + Tests that constructing tensors in frozen functions is possible, and does + not cause leaks. 
+ """ + mod = sys.modules[t.__module__] + Float32 = mod.Float32 + TensorXf = mod.TensorXf + + def func(x): + return TensorXf(x + 1) + + frozen = dr.freeze(func) + + for i in range(3): + x = dr.arange(Float32, 100) + ref = func(x) + res = frozen(x) + assert dr.all(res == ref) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test65_assign_tensor(t): + """ + Tests that assigning tensors to the input of frozen functions is possible, + and does not cause leaks. + """ + mod = sys.modules[t.__module__] + Float32 = mod.Float32 + TensorXf = mod.TensorXf + + def func(x): + x += 1 + + frozen = dr.freeze(func) + + for i in range(3): + x = TensorXf(dr.arange(Float32, 100)) + func(x) + ref = x + + x = TensorXf(dr.arange(Float32, 100)) + frozen(x) + res = x + assert dr.all(res == ref) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test66_closure(t): + + c1 = 1 + c2 = t(2) + + def func(x): + return x + c1 + c2 + + frozen = dr.freeze(func) + + for i in range(3): + x = dr.arange(t, i + 2) + ref = func(x) + + x = dr.arange(t, i + 2) + res = frozen(x) + + assert dr.allclose(ref, res) + + assert frozen.n_recordings == 1 + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test67_mutable_closure(t): + """ + Test that it is possible to use and modify closures in frozen functions. + """ + y1 = t(1, 2, 3) + y2 = t(1, 2, 3) + + def func(x): + nonlocal y1 + y1 += x + + @dr.freeze + def frozen(x): + nonlocal y2 + y2 += x + + for i in range(3): + x = t(i) + dr.make_opaque(x) + + func(x) + frozen(x) + + assert dr.allclose(y1, y2) + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test68_state_decorator(t): + mod = sys.modules[t.__module__] + Float = mod.Float32 + UInt32 = mod.UInt32 + + # Note: not a dataclass or DRJIT_STRUCT + class MyClass: + def __init__(self): + self.something1 = 4.5 + self.something2 = Float([1, 2, 3, 4, 5]) + + @dr.freeze(state_fn=lambda self, *_, **__: (self.something2)) + def frozen(self, x: Float, idx: UInt32) -> Float: + return x * self.something1 * dr.gather(Float, self.something2, idx) + + def func(self, x: Float, idx: UInt32) -> Float: + return x * self.something1 * dr.gather(Float, self.something2, idx) + + c = MyClass() + + for i in range(3): + x = dr.arange(Float, i + 2) + idx = dr.arange(UInt32, i + 2) + + res = c.frozen(x, idx) + ref = c.func(x, idx) + + assert dr.allclose(res, ref) + assert c.frozen.n_recordings == 1 + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("limit", (-1, None, 0, 1, 2)) +def test69_max_cache_size(t, limit): + """ + Tests different cache size limitations for the frozen function. + """ + + def func(x, p): + return x + p + + frozen = dr.freeze(func, limit=limit) + + n = 3 + for i in range(n): + + x = t(i, i+1, i+2) + + res = frozen(x, i) + ref = func(x, i) + + assert dr.allclose(res, ref) + + if limit == -1 or limit is None: + assert frozen.n_recordings == n + assert frozen.n_cached_recordings == n + elif limit == 0: + assert frozen.n_recordings == 0 + assert frozen.n_cached_recordings == 0 + else: + assert frozen.n_recordings == n + assert frozen.n_cached_recordings == limit + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test69_lru_eviction(t): + """ + Tests that the least recently used cache entry is evicted from the frozen + function if the cache size is limited. 
+    """
+
+    def func(x, p):
+        return x + p
+
+    frozen = dr.freeze(func, limit=2)
+
+    x = t(0, 1, 2)
+
+    # Create two entries in the cache
+    frozen(x, 0)
+    frozen(x, 1)
+
+    # This should evict the first one
+    frozen(x, 2)
+
+    assert frozen.n_recordings == 3
+
+    # p = 1 should still be in the cache, and calling it should not increment
+    # the recording counter.
+    frozen(x, 1)
+
+    assert frozen.n_recordings == 3
+
+    # p = 0 should be evicted, and calling it will increment the counter
+    frozen(x, 0)
+
+    assert frozen.n_recordings == 4
+
+
+@pytest.test_arrays("float32, jit, shape=(*)")
+def test70_warn_recordings(t):
+    """
+    This test simply calls the frozen function with incompatible inputs and
+    should print two warnings.
+    """
+
+    def func(x, i):
+        return x + i
+
+    frozen = dr.freeze(func, warn_after=2)
+
+    for i in range(4):
+        x = t(1, 2, 3)
+        frozen(x, i)
+
+
+@pytest.test_arrays("float32, jit, shape=(*)")
+@pytest.mark.parametrize("force_optix", [True, False])
+def test71_texture(t, force_optix):
+    mod = sys.modules[t.__module__]
+    Texture1f = mod.Texture1f
+    Float = mod.Float32
+
+    def func(tex: Texture1f, pos: Float):
+        return tex.eval(pos)
+
+    frozen = dr.freeze(func)
+
+    with dr.scoped_set_flag(dr.JitFlag.ForceOptiX, force_optix):
+        n = 4
+        for i in range(3):
+            tex = Texture1f([2], 1, True, dr.FilterMode.Linear, dr.WrapMode.Repeat)
+            tex.set_value(t(0, 1))
+
+            pos = dr.arange(Float, i + 2) / n
+
+            res = frozen(tex, pos)
+            ref = func(tex, pos)
+
+            assert dr.allclose(res, ref)
+
+    assert frozen.n_recordings < n
+
+
+@pytest.test_arrays("float32, jit, shape=(*)")
+def test72_no_input(t):
+    mod = sys.modules[t.__module__]
+
+    backend = dr.backend_v(t)
+
+    def func():
+        return dr.arange(t, 10)
+
+    frozen = dr.freeze(func, backend=backend)
+
+    for i in range(3):
+        res = frozen()
+        ref = func()
+
+        assert dr.allclose(res, ref)
+
+    wrong_backend = (
+        dr.JitBackend.CUDA if backend == dr.JitBackend.LLVM else dr.JitBackend.LLVM
+    )
+
+    frozen = dr.freeze(func, backend=wrong_backend)
+
+    with pytest.raises(RuntimeError):
+        for i in range(3):
+            res = frozen()
+
+
+@pytest.test_arrays("float32, jit, shape=(*)")
+def test76_changing_literal_width_holder(t):
+
+    class MyHolder:
+        DRJIT_STRUCT = {"lit": t}
+
+        def __init__(self, lit):
+            self.lit = lit
+
+    def func(x: t, lit: MyHolder):
+        return x + 1
+
+    # Note: only fails with auto_opaque=True
+    frozen = dr.freeze(func, warn_after=3)
+
+    n = 10
+    for i in range(n):
+        holder = MyHolder(dr.zeros(dr.tensor_t(t), (i + 1) * 10))
+        x = holder.lit + 0.5
+        dr.make_opaque(x)
+
+        res = frozen(x, holder)
+        ref = func(x, holder)
+
+        assert dr.allclose(ref, res)
+
+    assert frozen.n_recordings == 1
+
+
+@pytest.test_arrays("float32, jit, diff, shape=(*)")
+@pytest.mark.parametrize("optimizer", ["sgd", "rmsprop", "adam"])
+def test77_optimizers(t, optimizer):
+    n = 10
+
+    def func(y, opt):
+        loss = dr.mean(dr.square(opt["x"] - y))
+
+        dr.backward(loss)
+
+        opt.step()
+
+        return opt["x"], loss
+
+    def init_optimizer():
+        if optimizer == "sgd":
+            opt = dr.opt.SGD(lr=0.001, momentum=0.9)
+        elif optimizer == "rmsprop":
+            opt = dr.opt.RMSProp(lr=0.001)
+        elif optimizer == "adam":
+            opt = dr.opt.Adam(lr=0.001)
+        return opt
+
+    frozen = dr.freeze(func)
+
+    opt_func = init_optimizer()
+    opt_frozen = init_optimizer()
+
+    for i in range(n):
+        x = dr.full(t, 1, 10)
+        y = dr.full(t, 0, 10)
+
+        opt_func["x"] = x
+        opt_frozen["x"] = x
+
+        res_x, res_loss = frozen(y, opt_frozen)
+        ref_x, ref_loss = func(y, opt_func)
+
+        assert dr.allclose(res_x, ref_x)
+        assert 
dr.allclose(res_loss, ref_loss) diff --git a/tests/while_loop_ext.cpp b/tests/while_loop_ext.cpp index 64613e7f..e207a38a 100644 --- a/tests/while_loop_ext.cpp +++ b/tests/while_loop_ext.cpp @@ -3,6 +3,7 @@ #include #include #include +#include namespace nb = nanobind; namespace dr = drjit; @@ -37,11 +38,13 @@ struct Sampler { T next() { return rng.next_float32(); } - void traverse_1_cb_ro(void *payload, void (*fn)(void *, uint64_t)) const { + void traverse_1_cb_ro(void *payload, + dr::detail::traverse_callback_ro fn) const { traverse_1_fn_ro(rng, payload, fn); } - void traverse_1_cb_rw(void *payload, uint64_t (*fn)(void *, uint64_t)) { + void traverse_1_cb_rw(void *payload, + dr::detail::traverse_callback_rw fn) { traverse_1_fn_rw(rng, payload, fn); }