diff --git a/src/sage/misc/lazy_import.pyx b/src/sage/misc/lazy_import.pyx index 2eef113810c..d5fd17e54c7 100644 --- a/src/sage/misc/lazy_import.pyx +++ b/src/sage/misc/lazy_import.pyx @@ -73,7 +73,9 @@ AUTHOR: # lazy imports support old and true division. cimport cython -from cpython.object cimport PyObject_RichCompare +cimport cpython.type +from cpython.object cimport PyObject_RichCompare, newfunc, PyObject, PyTypeObject +from cpython.ref cimport Py_INCREF from cpython.number cimport PyNumber_TrueDivide, PyNumber_Power, PyNumber_Index cdef extern from *: @@ -996,6 +998,52 @@ cdef class LazyImport(): return False +cdef newfunc default_tp_new = (LazyImport).tp_new + + +cdef object lazyimport_tp_new(cpython.type.type t, PyObject* args, PyObject* kwds): + """ + We want code of the following form to work:: + + sage: from sage.misc.lazy_import import LazyImport + sage: TensorAlgebra = LazyImport('sage.algebras.tensor_algebra', 'TensorAlgebra') + sage: class Example(TensorAlgebra): pass + sage: Example(CombinatorialFreeModule(QQ, ["a", "b"])) + Tensor Algebra of Free module generated by {'a', 'b'} over Rational Field + + The reason why it is problematic is that ``class Example(TensorAlgebra): pass`` is essentially + syntactic sugar of the following:: + + sage: Example = type(TensorAlgebra)('Example', (TensorAlgebra,), {"__module__": "__main__", "__qualname__": "Example"}) + sage: Example + + sage: Example(CombinatorialFreeModule(QQ, ["a", "b"])) + Tensor Algebra of Free module generated by {'a', 'b'} over Rational Field + + But:: + + sage: type(TensorAlgebra) + + + To make it work, :meth:`LazyImport.__new__` must return something that is **not** a ``LazyImport** + instance. So we override ``__new__`` with this function. + """ + assert args != NULL + if len(args) >= 2 and isinstance((args)[1], str): + return default_tp_new(t, args, kwds) + assert t is LazyImport, t + future_cls_name, future_cls_parents, future_cls_attrs = args + assert len(future_cls_parents) >= 1 + cdef LazyImport l = future_cls_parents[0] + o = l.get_object() + assert kwds == NULL + return type(o).__new__(type(o), future_cls_name, + tuple(o if p is l else p for p in future_cls_parents), future_cls_attrs) + + +(LazyImport).tp_new = &lazyimport_tp_new + + def lazy_import(module, names, as_=None, *, at_startup=False, namespace=None, deprecation=None, feature=None):