 #include "autodiff.h"
 #include "base.h"
 #include "common.h"
+#include "reduce.h"
 #include "listobject.h"
 #include "object.h"
 #include "pyerrors.h"
@@ -139,7 +140,7 @@ static void log_layouts(const std::vector<Layout> &layouts, std::ostream &os,
         auto tp_name = layout.type ? nb::type_name(layout.type).c_str() : "None";
         os << padding << "type = " << tp_name << std::endl;
         os << padding << "num: " << layout.num << std::endl;
-        os << padding << "flats: " << std::bitset<8>(layout.flags) << std::endl;
+        os << padding << "flags: " << std::bitset<8>(layout.flags) << std::endl;
         os << padding << "index: " << layout.index << std::endl;
         os << padding << "py_object: " << nb::str(layout.py_object).c_str()
            << std::endl;
@@ -678,8 +679,15 @@ void FlatVariables::traverse(nb::handle h, TraverseContext &ctx) {
         if (s.is_tensor) {
             nb::handle array = s.tensor_array(h.ptr());
 
-            layout.py_object = shape(h);
-            layout.index = width(array);
+            auto full_shape = nb::borrow<nb::tuple>(shape(h));
+
+            nb::list outer_shape;
+            if (full_shape.size() > 0)
+                for (uint32_t i = 0; i < full_shape.size() - 1; i++) {
+                    outer_shape.append(full_shape[i]);
+                }
+
+            layout.py_object = nb::tuple(outer_shape);
 
             traverse(nb::steal(array), ctx);
         } else if (s.ndim != 1) {
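The `traverse()` change above stores only the outer dimensions of a tensor's shape in `layout.py_object`; the innermost dimension is recovered later in `construct()` from the size of the flattened array, which (presumably the point of the change) keeps the recorded layout independent of the trailing dimension's concrete size. A standalone sketch of that round-trip invariant (not part of this diff), using plain `std::vector` in place of `nb::tuple`/`nb::list` and made-up values:

```cpp
// Standalone illustration only -- hypothetical values.
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    std::vector<std::size_t> full_shape = {2, 3, 4}; // e.g. a (2, 3, 4) tensor

    // traverse(): keep everything but the innermost dimension.
    std::vector<std::size_t> outer_shape(full_shape.begin(),
                                         full_shape.end() - 1);

    // construct(): recover the innermost dimension from the flattened
    // array's element count divided by the product of the outer dims.
    std::size_t flat_size = 2 * 3 * 4;
    std::size_t outer_prod = 1;
    for (std::size_t d : outer_shape)
        outer_prod *= d;
    std::size_t last_dim = flat_size / outer_prod;

    assert(last_dim == 4); // (2, 3) + (4,) round-trips to (2, 3, 4)
    return 0;
}
```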
@@ -816,7 +824,18 @@ nb::object FlatVariables::construct() {
         const ArraySupplement &s = supp(layout.type);
         if (s.is_tensor) {
             nb::object array = construct();
-            nb::object tensor = layout.type(array, layout.py_object);
+
+            auto outer_shape = nb::borrow<nb::tuple>(layout.py_object);
+            auto last_dim = prod(shape(array), nb::none())
+                                .floor_div(prod(outer_shape, nb::none()));
+
+            nb::list full_shape;
+            for (uint32_t i = 0; i < outer_shape.size(); i++) {
+                full_shape.append(outer_shape[i]);
+            }
+            full_shape.append(last_dim);
+
+            nb::object tensor = layout.type(array, nb::tuple(full_shape));
             return tensor;
         } else if (s.ndim != 1) {
             auto result = nb::inst_alloc_zero(layout.type);
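One assumption worth flagging in the reconstruction above: `floor_div` only yields the correct trailing dimension when the flat array's element count is an exact multiple of the outer-shape product. That holds here because the outer shape was recorded from the same tensor during traversal, but a sketch of the silent failure mode if it ever didn't (made-up numbers, plain C++):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    std::size_t flat_size  = 10;    // hypothetical flattened element count
    std::size_t outer_prod = 2 * 3; // product of the stored outer dims
    std::size_t last_dim   = flat_size / outer_prod; // floors to 1
    // 2 * 3 * 1 != 10: elements would be dropped without any error, which
    // is why the divisibility guarantee from traverse() matters.
    std::printf("recovered last_dim = %zu\n", last_dim);
    return 0;
}
```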
@@ -1464,7 +1483,7 @@ nb::object FrozenFunction::operator()(nb::args args, nb::kwargs kwargs) {
     }
 
     if (it == this->recordings.end()) {
-#ifndef NDEBUG
+        // #ifndef NDEBUG
         if (this->recordings.size() >= 1) {
             jit_log(LogLevel::Info,
                     "Function input mismatch! Function will be retraced.");
@@ -1475,11 +1494,11 @@ nb::object FrozenFunction::operator()(nb::args args, nb::kwargs kwargs) {
         std::ostringstream repr_prev;
         repr_prev << *prev_key;
 
-        jit_log(LogLevel::Debug, "new key: %s", repr.str().c_str());
-        jit_log(LogLevel::Debug, "old key: %s",
+        jit_log(LogLevel::Warn, "new key: %s", repr.str().c_str());
+        jit_log(LogLevel::Warn, "old key: %s",
                 repr_prev.str().c_str());
     }
-#endif
+    // #endif
 
     {
         // TODO: single traverse
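For context on the guard being commented out above: an `#ifndef NDEBUG` block is stripped by the preprocessor whenever `NDEBUG` is defined, which is the norm for release builds, so the retrace diagnostics were previously invisible exactly where frozen-function cache misses tend to be investigated. Disabling the guard and raising the key dumps from `Debug` to `Warn` makes them available in release builds, at the cost of always compiling them in. A minimal standalone illustration of the gating behavior:

```cpp
#include <cstdio>

int main() {
#ifndef NDEBUG
    // Compiled out entirely when building with -DNDEBUG (release).
    std::printf("debug-only diagnostics\n");
#endif
    std::printf("always printed\n");
    return 0;
}
```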