Commit d47afff

Added TraversableBase class for c++ type traversal
Improved traversal of c++ iterators and references
Using TraversableBase for Texture traversal
Added traversal helpers for trampoline classes
Fixed double definition of traverse_py_cb_* methods
Changed traverse function, allowing assignment of c++ objects
Added traversal tests
Added ad_scope_postponed function
Added freezing feature
Added freezing tests
Improved logging
Cleanup of freeze.cpp
Added declarations in freeze.h
Reordered traversal functions
Cleanup and added clear function
Test ordering
Removed outdated fix comment
Improved exceptions
Minor performance improvements
Using nb::hash for key hashing
Using the FreezingScope flag for passthrough recording
Removed kw_only attribute from dataclasses in freeze tests
Changed traversable base casting
Wrapped DRJIT_EXPORT in ifndef
Added warning for traversable base casting
Temporary fix for registry traversal
First freeze documentation attempt
Improved registry traversal and insertion
Cleanup freezing tests
Updated drjit-core submodule
Enabled retrieving the variant of c++ objects
Added variance mismatch test
Attempt at variant traversal in traverse_cb_ro
Added registry domain traversal
Moved declaration of `freeze` to python and added `traverse_callback_*` aliases
Formatting and removed print in tests
Removed default impls for Variant and Domain
Fixed leak of FrozenFunctions by custom currying class
Implemented domain traversal
Added comments for frozen function tests
Removed prints from tests
Added comments for frozen function tests
Fixed typo
Fixed traverse test
Remove b.value()
Formatting
Added comments for traversable_base.h and fixed some include patterns
Put freezing profiler phase behind ndebug flag
Added traversable_base.h comments
Handled comments from Merlin
Added comments in apply.h
Moved freeze documentation to __init__.py
Added suggestions from Merlin
Added default value to resize for registry traversal
Fixed unnecessary format change
Changed docstring comments for custom_type_ext tests
Changed include style
Comments
Fixed tensors getting leaked
Removed tracing log levels
Fixed tensor assignment leak
Improved profiling helper
Using move only Layout
Fixed variable mismatch
Fixed out of bounds write
Added fast callback traversal
Using info from `jit_set_backend`
Improved traversal of opaque C++ objects
Improved layout for ad/jit indices
Reduced jit_var_inc_ref overhead
Improved traversal api
Removed redundant logging calls
Added Profile flags behind DRJIT_ENABLE_NVTX
Removed indices vector
Added domain and variant to rw traversal functions
Fixed warnings
Fixed logging of key difference
Fixed segfault with logging
Reverted on std::function helpers
Fixed double traversal of C++ objects
Added heuristic for reserving layout and maps in FlatVariables
Added back-assignment
Fixed variant domain not being added at traversal
Refactor to use VarLayout
Deferred traversal
Removed input eval
Removed deep_make_opaque and deep_eval
Comments
Removed RecordingKey
Moved assignment into record path
Fixed variable leaks on failures
Removed fix comments
Removed redundant code
Renamed Hasher and Equal and added comments
Comments
Comments
Cleanup
1 parent b656011 commit d47afff

25 files changed: +5499 -60 lines changed

drjit/__init__.py

Lines changed: 96 additions & 0 deletions
@@ -1878,6 +1878,102 @@ def binary_search(start, end, pred):

     return start

+def freeze(f):
+    '''
+    Decorator to freeze a function for replaying kernels without compilation.
+
+    This decorator wraps a function and enables replaying it in order to remove the
+    need for tracing Python operations. The first time a frozen function is called,
+    it is executed regularly, while all operations performed are recorded. When the
+    frozen function is called again with compatible arguments, the recorded
+    operations are replayed instead of launching the Python function. This saves on
+    the overhead that is inherent when tracing and compiling Python operations into
+    kernels. The frozen function keeps a record of all previously recorded calls,
+    and associates them with the input layout. Since the kernels might be compiled
+    differently when changing certain attributes of variables, the frozen function
+    will be re-recorded if the inputs are no longer compatible with previous
+    recordings. We also track the layout of any container and the values of Python
+    types of the inputs to the frozen function. The following are the input
+    attributes which are tracked and might cause the function to be re-recorded if
+    changed.
+
+    - Python type of any variable or container
+    - Number of members in a container, such as the length of a list
+    - Key or field names, such as dictionary keys or dataclass field names
+    - The `Dr.Jit` type of any variable
+    - Whether a variable has size $1$
+    - Whether the memory referenced by a variable is unaligned (only applies to
+      mapped NumPy arrays)
+    - Whether the variable has gradients enabled
+    - The sets of variables that have the same size
+    - The hash of any Python variable that is not a Dr.Jit variable or a container
+
+    The width of a variable itself is not tracked, as we want to allow replaying of
+    kernels with different sizes. However, if multiple variables are scheduled and
+    evaluated at the same time, they can get compiled into the same kernel if their
+    size matches. Therefore, we track the size class of the variable, and re-record
+    the kernel if it changed. Because the size is not tracked, it is leaked into the
+    frozen function context and we infer the size of kernels using a heuristic based
+    on the sizes of the input variables. This allows gathering from variables with
+    their width as an argument. However, using the size of a variable in more
+    complicated computations might lead to undefined behavior.
+
+    ```python
+    y = dr.gather(type(x), x, dr.width(x)//2)
+    ```
+
+    Similarly, calculating the mean of a variable relies on the number of entries,
+    which will be baked into the frozen function. To avoid this, we suggest
+    supplying the number of entries as a Dr.Jit literal in the arguments to the
+    function.
+    '''
+    import functools
+
+    class FrozenFunction:
+        def __init__(self, f) -> None:
+            self.ff = detail.FrozenFunction(f)
+
+        def __call__(self, *args, **kwargs):
+            return self.ff(*args, **kwargs)
+
+        @property
+        def n_recordings(self):
+            return self.ff.n_recordings
+
+        @property
+        def n_cached_recordings(self):
+            return self.ff.n_cached_recordings
+
+        def clear(self):
+            return self.ff.clear()
+
+        def __get__(self, obj, type=None):
+            if obj is None:
+                return self
+            else:
+                return FrozenMethod(self.ff, obj)
+
+    class FrozenMethod(FrozenFunction):
+        """
+        A FrozenMethod currying the object into the __call__ method.
+
+        If the ``freeze`` decorator is applied to a method of some class, it has
+        to call the internal frozen function with the ``self`` argument. To this
+        end we implement the ``__get__`` method of the frozen function, to
+        return a ``FrozenMethod``, which holds a reference to the object.
+        The ``__call__`` method of the ``FrozenMethod`` then supplies the object
+        in addition to the arguments to the internal function.
+        """
+        def __init__(self, ff, obj) -> None:
+            super().__init__(ff)
+            self.obj = obj
+
+        def __call__(self, *args, **kwargs):
+            return self.ff(self.obj, *args, **kwargs)
+
+
+    return functools.wraps(f)(FrozenFunction(f))

 def assert_true(
     cond,
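
The layout tracking described in the `freeze` docstring can be illustrated with a toy model in plain Python. This is only a sketch under stated assumptions, not Dr.Jit's implementation: `FakeArray`, `layout_key` and `ToyFrozenFunction` are hypothetical names, nothing is actually recorded or replayed, and only a subset of the tracked attributes (container type and length, dictionary keys, size-1 flag, gradient flag, size classes, hash of other Python values) is modeled. It also shows how a `__get__` hook can curry the instance into the call, as the `FrozenMethod` docstring describes.

```python
import functools


class FakeArray:
    """Hypothetical stand-in for a Dr.Jit array (not the real type)."""

    def __init__(self, values, grad_enabled=False):
        self.values = list(values)
        self.grad_enabled = grad_enabled


def layout_key(value, size_classes):
    """Return a hashable layout description that ignores array widths."""
    if isinstance(value, FakeArray):
        width = len(value.values)
        # Arrays of equal width share a size class, numbered by first appearance.
        size_class = size_classes.setdefault(width, len(size_classes))
        return ("array", width == 1, value.grad_enabled, size_class)
    if isinstance(value, (list, tuple)):
        items = tuple(layout_key(v, size_classes) for v in value)
        return (type(value).__name__, items)
    if isinstance(value, dict):
        items = tuple((k, layout_key(v, size_classes)) for k, v in value.items())
        return ("dict", items)
    # Any other Python object is tracked by its type and hash.
    return (type(value).__name__, hash(value))


class ToyFrozenFunction:
    """Remembers which distinct input layouts a function was called with."""

    def __init__(self, f):
        self.f = f
        self.recorded_layouts = set()

    def __call__(self, *args):
        key = layout_key(args, {})
        # A new layout would trigger a fresh recording in the real feature;
        # this toy only remembers the key and always runs the function.
        self.recorded_layouts.add(key)
        return self.f(*args)

    def __get__(self, obj, objtype=None):
        # Descriptor hook: bind the instance as the first argument while
        # sharing the same layout set across all bound calls.
        if obj is None:
            return self
        return functools.partial(self, obj)
```

In this model, calling the wrapped function with two arrays of width 3 and then with two arrays of width 8 yields a single layout, whereas widths 8 and 4 produce a second one because the arrays no longer share a size class.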

include/drjit/array_traverse.h

Lines changed: 60 additions & 6 deletions
@@ -15,6 +15,7 @@

 #pragma once

+#include <type_traits>
 #define DRJIT_STRUCT_NODEF(Name, ...) \
     Name(const Name &) = default; \
     Name(Name &&) = default; \
@@ -140,6 +141,18 @@ namespace detail {
     using det_traverse_1_cb_rw =
         decltype(T(nullptr)->traverse_1_cb_rw(nullptr, nullptr));

+    template <typename T>
+    using det_get = decltype(std::declval<T&>().get());
+
+    template <typename T>
+    using det_const_get = decltype(std::declval<const T &>().get());
+
+    template<typename T>
+    using det_begin = decltype(std::declval<T &>().begin());
+
+    template<typename T>
+    using det_end = decltype(std::declval<T &>().end());
+
     inline drjit::string get_label(const char *s, size_t i) {
         auto skip = [](char c) {
             return c == ' ' || c == '\r' || c == '\n' || c == '\t' || c == ',';
@@ -180,10 +193,17 @@ template <typename T> auto labels(const T &v) {
 }

 template <typename Value>
-void traverse_1_fn_ro(const Value &value, void *payload, void (*fn)(void *, uint64_t)) {
-    (void) payload; (void) fn;
+void traverse_1_fn_ro(const Value &value, void *payload,
+                      void (*fn)(void *, uint64_t, const char *,
+                                 const char *)) {
+    (void) payload;
+    (void) fn;
     if constexpr (is_jit_v<Value> && depth_v<Value> == 1) {
-        fn(payload, value.index_combined());
+        if constexpr(Value::IsClass)
+            fn(payload, value.index_combined(), Value::CallSupport::Variant,
+               Value::CallSupport::Domain);
+        else
+            fn(payload, value.index_combined(), "", "");
     } else if constexpr (is_traversable_v<Value>) {
         traverse_1(fields(value), [payload, fn](auto &x) {
             traverse_1_fn_ro(x, payload, fn);
@@ -198,14 +218,36 @@ void traverse_1_fn_ro(const Value &value, void *payload, void (*fn)(void *, uint
                          is_detected_v<detail::det_traverse_1_cb_ro, Value>) {
         if (value)
             value->traverse_1_cb_ro(payload, fn);
+
+    } else if constexpr (is_detected_v<detail::det_begin, Value> &&
+                         is_detected_v<detail::det_end, Value>) {
+        for (auto elem : value) {
+            traverse_1_fn_ro(elem, payload, fn);
+        }
+    } else if constexpr (is_detected_v<detail::det_const_get, Value>) {
+        const auto *tmp = value.get();
+        traverse_1_fn_ro(tmp, payload, fn);
+    } else if constexpr (is_detected_v<detail::det_traverse_1_cb_ro, Value *>) {
+        value.traverse_1_cb_ro(payload, fn);
+    } else {
+        // static_assert(false, "Failed to traverse field!");
     }
 }

 template <typename Value>
-void traverse_1_fn_rw(Value &value, void *payload, uint64_t (*fn)(void *, uint64_t)) {
-    (void) payload; (void) fn;
+void traverse_1_fn_rw(Value &value, void *payload,
+                      uint64_t (*fn)(void *, uint64_t, const char *,
+                                     const char *)) {
+    (void) payload;
+    (void) fn;
     if constexpr (is_jit_v<Value> && depth_v<Value> == 1) {
-        value = Value::borrow((typename Value::Index) fn(payload, value.index_combined()));
+        if constexpr(Value::IsClass)
+            value = Value::borrow((typename Value::Index) fn(
+                payload, value.index_combined(), Value::CallSupport::Variant,
+                Value::CallSupport::Domain));
+        else
+            value = Value::borrow((typename Value::Index) fn(
+                payload, value.index_combined(), "", ""));
     } else if constexpr (is_traversable_v<Value>) {
         traverse_1(fields(value), [payload, fn](auto &x) {
             traverse_1_fn_rw(x, payload, fn);
@@ -220,6 +262,18 @@ void traverse_1_fn_rw(Value &value, void *payload, uint64_t (*fn)(void *, uint64
                          is_detected_v<detail::det_traverse_1_cb_rw, Value>) {
         if (value)
             value->traverse_1_cb_rw(payload, fn);
+    } else if constexpr (is_detected_v<detail::det_begin, Value> &&
+                         is_detected_v<detail::det_end, Value>) {
+        for (auto elem : value) {
+            traverse_1_fn_rw(elem, payload, fn);
+        }
+    } else if constexpr (is_detected_v<detail::det_get, Value>) {
+        auto *tmp = value.get();
+        traverse_1_fn_rw(tmp, payload, fn);
+    } else if constexpr (is_detected_v<detail::det_traverse_1_cb_rw, Value *>) {
+        value.traverse_1_cb_rw(payload, fn);
+    } else {
+        // static_assert(false, "Failed to traverse field!");
     }
 }
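
The dispatch order of `traverse_1_fn_ro` and `traverse_1_fn_rw` (leaf variable first, then containers, otherwise nothing) can be mirrored in a few lines of Python. This is a loose analogue rather than a translation of the C++ above: `Leaf`, `traverse_ro` and `traverse_rw` are hypothetical names, and the `variant`/`domain` strings simply stand in for the two extra `const char *` arguments this commit adds to the callbacks.

```python
class Leaf:
    """Hypothetical leaf wrapping a variable index (not a Dr.Jit type)."""

    def __init__(self, index, variant="", domain=""):
        self.index = index
        self.variant = variant
        self.domain = domain


def traverse_ro(value, callback):
    """Report every leaf in ``value`` to ``callback`` without modifying it."""
    if isinstance(value, Leaf):
        callback(value.index, value.variant, value.domain)
    elif isinstance(value, dict):
        for item in value.values():
            traverse_ro(item, callback)
    elif isinstance(value, (list, tuple)):
        for item in value:
            traverse_ro(item, callback)
    # Any other value is skipped silently, like the empty final ``else``.


def traverse_rw(value, callback):
    """Replace every leaf index with the value returned by ``callback``."""
    if isinstance(value, Leaf):
        value.index = callback(value.index, value.variant, value.domain)
    elif isinstance(value, dict):
        for item in value.values():
            traverse_rw(item, callback)
    elif isinstance(value, (list, tuple)):
        for item in value:
            traverse_rw(item, callback)


# Usage: collect (index, variant, domain) triples, then rewrite the indices.
tree = {"a": Leaf(1), "b": [Leaf(2, "some_variant", "SomeDomain"), 3.5]}
seen = []
traverse_ro(tree, lambda index, variant, domain: seen.append((index, variant, domain)))
traverse_rw(tree, lambda index, variant, domain: index + 100)
```

The read-write variant differs only in that the callback's return value is assigned back to the leaf, which is what allows a caller to swap in new variable indices.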

include/drjit/autodiff.h

Lines changed: 4 additions & 2 deletions
@@ -971,7 +971,8 @@ NAMESPACE_BEGIN(detail)
 /// Internal operations for traversing nested data structures and fetching or
 /// storing indices. Used in ``call.h`` and ``loop.h``.

-template <bool IncRef> void collect_indices_fn(void *p, uint64_t index) {
+template <bool IncRef>
+void collect_indices_fn(void *p, uint64_t index, const char *, const char *) {
     vector<uint64_t> &indices = *(vector<uint64_t> *) p;
     if constexpr (IncRef)
         index = ad_var_inc_ref(index);
@@ -983,7 +984,8 @@ struct update_indices_payload {
     size_t &pos;
 };

-inline uint64_t update_indices_fn(void *p, uint64_t) {
+inline uint64_t update_indices_fn(void *p, uint64_t, const char *,
+                                  const char *) {
     update_indices_payload &payload = *(update_indices_payload *) p;
     return payload.indices[payload.pos++];
 }
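
The collect-then-update round trip that `collect_indices_fn` and `update_indices_fn` implement can be sketched in Python with an explicit payload argument. The `Pair` and `UpdatePayload` classes are hypothetical and exist only for this illustration; the two trailing string arguments are ignored, matching the unnamed `const char *` parameters above.

```python
class Pair:
    """Hypothetical object exposing payload-style traversal callbacks."""

    def __init__(self, first, second):
        self.first = first
        self.second = second

    def traverse_1_cb_ro(self, payload, fn):
        fn(payload, self.first, "", "")
        fn(payload, self.second, "", "")

    def traverse_1_cb_rw(self, payload, fn):
        self.first = fn(payload, self.first, "", "")
        self.second = fn(payload, self.second, "", "")


def collect_indices_fn(payload, index, variant, domain):
    # The payload is a plain list that accumulates the visited indices.
    payload.append(index)


class UpdatePayload:
    """Replacement indices plus a cursor into them."""

    def __init__(self, indices):
        self.indices = indices
        self.pos = 0


def update_indices_fn(payload, index, variant, domain):
    # Ignore the old index and hand out the next replacement in order.
    new_index = payload.indices[payload.pos]
    payload.pos += 1
    return new_index


pair = Pair(10, 20)
indices = []
pair.traverse_1_cb_ro(indices, collect_indices_fn)  # indices == [10, 20]
update = UpdatePayload([i + 1 for i in indices])
pair.traverse_1_cb_rw(update, update_indices_fn)    # pair now holds 11 and 21
```

The scheme relies on both traversals visiting the members in the same order, since the replacement values are matched to the fields purely by position.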

include/drjit/extra.h

Lines changed: 5 additions & 0 deletions
@@ -247,6 +247,11 @@ extern DRJIT_EXTRA_EXPORT bool ad_release_one_output(drjit::detail::CustomOpBase
 extern DRJIT_EXTRA_EXPORT void ad_copy_implicit_deps(drjit::vector<uint32_t> &,
                                                      bool input);

+/// Retrieve a list of AD indices that are the targets of edges that have been
+/// postponed by the current scope
+extern DRJIT_EXTRA_EXPORT void ad_scope_postponed(drjit::vector<uint32_t> &);
+
+
 /// Kahan-compensated floating point atomic scatter-addition
 extern DRJIT_EXTRA_EXPORT void
 ad_var_scatter_add_kahan(uint64_t *target_1, uint64_t *target_2, uint64_t value,

include/drjit/python.h

Lines changed: 60 additions & 8 deletions
@@ -54,6 +54,9 @@
 #include <drjit/math.h>
 #include <drjit-core/python.h>
 #include <nanobind/stl/array.h>
+#include <nanobind/intrusive/counter.h>
+#include <nanobind/nanobind.h>
+#include <drjit/traversable_base.h>

 NAMESPACE_BEGIN(drjit)
 struct ArrayBinding;
@@ -1057,25 +1060,74 @@ template <typename T> void bind_all(ArrayBinding &b) {
 // Expose already existing object tree traversal callbacks (T::traverse_1_..) in Python.
 // This functionality is needed to traverse custom/opaque C++ classes and correctly
 // update their members when they are used in vectorized loops, function calls, etc.
-template <typename T, typename... Args> auto& bind_traverse(nanobind::class_<T, Args...> &cls) {
+template <typename T, typename... Args> auto &bind_traverse(nanobind::class_<T, Args...> &cls)
+{
     namespace nb = nanobind;
-    struct Payload { nb::callable c; };
+    struct Payload {
+        nb::callable c;
+    };
+
+    static_assert(std::is_base_of_v<TraversableBase, T>);

-    cls.def("_traverse_1_cb_ro", [](const T *self, nb::callable c) {
+    cls.def("_traverse_1_cb_ro", [](T *self, nb::callable c) {
         Payload payload{ std::move(c) };
-        self->traverse_1_cb_ro((void *) &payload, [](void *p, uint64_t index) {
-            ((Payload *) p)->c(index);
-        });
+        self->traverse_1_cb_ro((void *) &payload,
+                               [](void *p, uint64_t index, const char *variant,
+                                  const char *domain) {
+                                   ((Payload *) p)->c(index, variant, domain);
+                               });
     });

     cls.def("_traverse_1_cb_rw", [](T *self, nb::callable c) {
         Payload payload{ std::move(c) };
-        self->traverse_1_cb_rw((void *) &payload, [](void *p, uint64_t index) {
-            return nb::cast<uint64_t>(((Payload *) p)->c(index));
+        self->traverse_1_cb_rw((void *) &payload, [](void *p, uint64_t index,
+                                                     const char *variant,
+                                                     const char *domain) {
+            return nb::cast<uint64_t>(
+                ((Payload *) p)->c(index, variant, domain));
         });
     });

     return cls;
 }

+inline void traverse_py_cb_ro(const TraversableBase *base, void *payload,
+                              void (*fn)(void *, uint64_t, const char *variant,
+                                         const char *domain)) {
+    namespace nb = nanobind;
+    nb::handle self = base->self_py();
+    if (!self)
+        return;
+
+    auto detail = nb::module_::import_("drjit.detail");
+    nb::callable traverse_py_cb_ro =
+        nb::borrow<nb::callable>(nb::getattr(detail, "traverse_py_cb_ro"));
+
+    traverse_py_cb_ro(self,
+                      nb::cpp_function([&](uint64_t index, const char *variant,
+                                           const char *domain) {
+                          fn(payload, index, variant, domain);
+                      }));
+}
+
+inline void traverse_py_cb_rw(TraversableBase *base, void *payload,
+                              uint64_t (*fn)(void *, uint64_t, const char *,
+                                             const char *)) {
+
+    namespace nb = nanobind;
+    nb::handle self = base->self_py();
+    if (!self)
+        return;
+
+    auto detail = nb::module_::import_("drjit.detail");
+    nb::callable traverse_py_cb_rw =
+        nb::borrow<nb::callable>(nb::getattr(detail, "traverse_py_cb_rw"));
+
+    traverse_py_cb_rw(self,
+                      nb::cpp_function([&](uint64_t index, const char *variant,
+                                           const char *domain) {
+                          return fn(payload, index, variant, domain);
+                      }));
+}
+
 NAMESPACE_END(drjit)

include/drjit/texture.h

Lines changed: 5 additions & 1 deletion
@@ -17,6 +17,7 @@
 #include <drjit/idiv.h>
 #include <drjit/jit.h>
 #include <drjit/tensor.h>
+#include <drjit/traversable_base.h>

 #pragma once

@@ -41,7 +42,7 @@ enum class CudaTextureFormat : uint32_t {
     Float16 = 1, /// Half precision storage format
 };

-template <typename _Storage, size_t Dimension> class Texture {
+template <typename _Storage, size_t Dimension> class Texture : TraversableBase {
 public:
     static constexpr bool IsCUDA = is_cuda_v<_Storage>;
     static constexpr bool IsDiff = is_diff_v<_Storage>;
@@ -1386,6 +1387,9 @@ template <typename _Storage, size_t Dimension> class Texture {
     WrapMode m_wrap_mode;
     bool m_use_accel = false;
     mutable bool m_migrated = false;
+
+    DR_TRAVERSE_CB(drjit::TraversableBase, m_value, m_shape_opaque,
+                   m_inv_resolution);
 };

 NAMESPACE_END(drjit)
