
Commit 587196c

Allow for changes in the last dimension of tensor shapes
1 parent 1568ff1

File tree

2 files changed, 57 insertions(+), 8 deletions(-)

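As an overview: after this commit, only the leading ("outer") dimensions of a tensor's shape participate in a frozen function's input layout; the trailing dimension is recomputed when outputs are reconstructed. A minimal sketch of the intended behavior, assuming the LLVM backend and the dr.freeze decorator API (the function and tensor type here are illustrative, not part of the commit):

    import drjit as dr
    from drjit.llvm import TensorXf

    @dr.freeze
    def scale(t: TensorXf):
        return t * 2

    # The outer dimension (3) stays fixed; only the last dimension varies.
    for n in (4, 8, 16):
        t = dr.zeros(TensorXf, shape=(3, n))
        dr.make_opaque(t)  # avoid literal-value specialization
        scale(t)

    # With this change, all three calls should replay a single recording.
    assert scale.n_recordings == 1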

src/python/freeze.cpp

Lines changed: 27 additions & 8 deletions
@@ -3,6 +3,7 @@
 #include "autodiff.h"
 #include "base.h"
 #include "common.h"
+#include "reduce.h"
 #include "listobject.h"
 #include "object.h"
 #include "pyerrors.h"
@@ -139,7 +140,7 @@ static void log_layouts(const std::vector<Layout> &layouts, std::ostream &os,
         auto tp_name = layout.type ? nb::type_name(layout.type).c_str() : "None";
         os << padding << "type = " << tp_name << std::endl;
         os << padding << "num: " << layout.num << std::endl;
-        os << padding << "flats: " << std::bitset<8>(layout.flags) << std::endl;
+        os << padding << "flags: " << std::bitset<8>(layout.flags) << std::endl;
         os << padding << "index: " << layout.index << std::endl;
         os << padding << "py_object: " << nb::str(layout.py_object).c_str()
            << std::endl;
@@ -678,8 +679,15 @@ void FlatVariables::traverse(nb::handle h, TraverseContext &ctx) {
     if (s.is_tensor) {
         nb::handle array = s.tensor_array(h.ptr());

-        layout.py_object = shape(h);
-        layout.index = width(array);
+        auto full_shape = nb::borrow<nb::tuple>(shape(h));
+
+        nb::list outer_shape;
+        if (full_shape.size() > 0)
+            for (uint32_t i = 0; i < full_shape.size() - 1; i++) {
+                outer_shape.append(full_shape[i]);
+            }
+
+        layout.py_object = nb::tuple(outer_shape);

         traverse(nb::steal(array), ctx);
     } else if (s.ndim != 1) {
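
The traversal now stores only the leading dimensions in layout.py_object and no longer records the array width in layout.index. A Python sketch of this keying rule (illustration only, not the actual C++ API):

    def tensor_key(shape: tuple) -> tuple:
        # Keep every dimension except the last; empty shapes key on ().
        return shape[:-1] if len(shape) > 0 else ()

    assert tensor_key((3, 4, 100)) == (3, 4)
    assert tensor_key((100,)) == ()  # 1-D tensors match regardless of width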
@@ -816,7 +824,18 @@ nb::object FlatVariables::construct() {
     const ArraySupplement &s = supp(layout.type);
     if (s.is_tensor) {
         nb::object array = construct();
-        nb::object tensor = layout.type(array, layout.py_object);
+
+        auto outer_shape = nb::borrow<nb::tuple>(layout.py_object);
+        auto last_dim = prod(shape(array), nb::none())
+                            .floor_div(prod(outer_shape, nb::none()));
+
+        nb::list full_shape;
+        for (uint32_t i = 0; i < outer_shape.size(); i++) {
+            full_shape.append(outer_shape[i]);
+        }
+        full_shape.append(last_dim);
+
+        nb::object tensor = layout.type(array, nb::tuple(full_shape));
         return tensor;
     } else if (s.ndim != 1) {
         auto result = nb::inst_alloc_zero(layout.type);
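
On reconstruction, the trailing dimension is derived rather than stored: the element count of the flat array is divided by the product of the recorded outer dimensions, which is what the prod(...).floor_div(...) expression above computes. The same arithmetic in plain Python (assuming the element count divides evenly, as it does for arrays produced by the matching traversal):

    from math import prod

    def restore_shape(outer: tuple, n_elements: int) -> tuple:
        last_dim = n_elements // prod(outer)  # prod(()) == 1 handles 1-D tensors
        return (*outer, last_dim)

    assert restore_shape((3, 4), 1200) == (3, 4, 100)
    assert restore_shape((), 17) == (17,)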
@@ -1464,7 +1483,7 @@ nb::object FrozenFunction::operator()(nb::args args, nb::kwargs kwargs) {
     }

     if (it == this->recordings.end()) {
-#ifndef NDEBUG
+        // #ifndef NDEBUG
         if (this->recordings.size() >= 1) {
             jit_log(LogLevel::Info,
                     "Function input missmatch! Function will be retraced.");
@@ -1475,11 +1494,11 @@ nb::object FrozenFunction::operator()(nb::args args, nb::kwargs kwargs) {
             std::ostringstream repr_prev;
             repr_prev << *prev_key;

-            jit_log(LogLevel::Debug, "new key: %s", repr.str().c_str());
-            jit_log(LogLevel::Debug, "old key: %s",
+            jit_log(LogLevel::Warn, "new key: %s", repr.str().c_str());
+            jit_log(LogLevel::Warn, "old key: %s",
                     repr_prev.str().c_str());
         }
-#endif
+        // #endif

         {
             // TODO: single traverse
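
The last two hunks additionally comment out the #ifndef NDEBUG guard and raise the key-mismatch diagnostics from Debug to Warn, so unintended retraces are reported in release builds as well. Assuming Dr.Jit exposes its log level via set_log_level (an assumption; the binding is not shown in this commit), the reports can be surfaced from Python:

    import drjit as dr

    # Show the "new key: ..." / "old key: ..." reports emitted when a
    # frozen function has to be retraced.
    dr.set_log_level(dr.LogLevel.Warn)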

tests/test_freeze.py

Lines changed: 30 additions & 0 deletions
@@ -2783,3 +2783,33 @@ def func():
     ref = func()

     assert dr.allclose(res, ref)
+
+
+@pytest.test_arrays("float32, jit, shape=(*)")
+def test76_changing_literal_width_holder(t):
+
+    class MyHolder:
+        DRJIT_STRUCT = {"lit": t}
+        def __init__(self, lit):
+            self.lit = lit
+
+    def func(x: t, lit: MyHolder):
+        return x + 1
+
+    # Note: only fails with auto_opaque=True
+    frozen = dr.freeze(func, warn_recording_count=3)
+
+    n = 10
+    for i in range(n):
+        holder = MyHolder(dr.zeros(dr.tensor_t(t), (i + 1) * 10))
+        x = holder.lit + 0.5
+        dr.make_opaque(x)
+
+        res = frozen(x, holder)
+        ref = func(x, holder)
+
+        assert dr.allclose(ref, res)
+
+    assert frozen.n_recordings == 1
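
The new test exercises the 1-D case: holder.lit is a tensor whose width grows by 10 on every iteration, so the old layout key (full shape plus array width) would have forced a retrace on each call. Under the new keying, a shape of (n,) reduces to the empty outer tuple, and the final assertion frozen.n_recordings == 1 verifies that a single recording is replayed across all ten widths.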
