diff --git a/CMakeLists.txt b/CMakeLists.txt index e9070f1a..3cfc2dae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -108,7 +108,7 @@ if (DRJIT_ENABLE_JIT) set_target_properties(nanothread PROPERTIES ${DRJIT_OUTPUT_DIRECTORY}) endif() -mark_as_advanced(NANOTHREAD_ENABLE_TESTS) +mark_as_advanced(NANOTHREAD_ENABLE_TESTS NANOTHREAD_STATIC) mark_as_advanced(DRJIT_CORE_ENABLE_TESTS) mark_as_advanced(NB_TEST NB_TEST_SHARED_BUILD NB_TEST_STABLE_ABI NB_USE_SUBMODULE_DEPS NB_TEST_SANITZE NB_CREATE_INSTALL_RULES nanobind_DIR) mark_as_advanced(NB_TEST_CUDA NB_TEST_FREE_THREADED NB_TEST_SANITIZERS_ASAN NB_TEST_SANITIZERS_TSAN NB_TEST_SANITIZERS_UBSAN) diff --git a/docs/autodiff.rst b/docs/autodiff.rst index 805ebfea..e3eb7b3a 100644 --- a/docs/autodiff.rst +++ b/docs/autodiff.rst @@ -427,8 +427,8 @@ Dr.Jit how a particular operation should be differentiated. Reasons for this may include: - The automatic differentiation backend cannot keep track of computation - performed outside of Dr.Jit (e.g. using a highly optimized :ref:`CUDA kernel - `). In this case, review the section on :ref:`interoperability + performed outside of Dr.Jit (e.g. using custom CUDA kernels). In this case, + review the section on :ref:`interoperability `, since it presents a potentially simpler solution. - The derivative may admit a simplified analytic expression that is superior to diff --git a/docs/changelog.rst b/docs/changelog.rst index 6f24f3d3..63818329 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -348,7 +348,7 @@ Here is what's new: ⚠️ Compatibility ⚠️ ------------------- +------------------- - **Symbolic loop syntax**: the old "recorded loop" syntax is no longer supported. Existing code will need adjustments to use diff --git a/docs/coop_vec.rst b/docs/coop_vec.rst new file mode 100644 index 00000000..0ce66789 --- /dev/null +++ b/docs/coop_vec.rst @@ -0,0 +1,317 @@ +.. py:currentmodule:: drjit + +.. cpp:namespace:: drjit + +.. _coop_vec: + +Cooperative vectors +=================== + +*Cooperative vectors* are a `new API +`__ +for evaluating matrix-vector products in certain types of GPU workloads. They +are designed to handle cases, where each thread of a parallel program needs +to multiply a vector by a reasonably small matrix (e.g., 64x64 or fewer +entries). By working together, the threads can perform these multiplications +more efficiently, which is why the approach is called *cooperative*. + +Cooperative vectors are especially useful for evaluating small `multilayer +perceptrons `__ (MLPs) +within larger programs while fully *fusing* all steps of the process into a +single kernel. Other workloads that heavily rely on matrix-vector products may +benefit as well. + +Dr.Jit supports cooperative vectors on both of its backends: + +- On **NVIDIA GPUs (Turing or newer)**, cooperative vectors map to the OptiX + `cooperative vector API + `__, + leveraging built-in `tensor cores + `__ for acceleration. + Driver version R570 or newer is required to use this feature. + +- On the **CPU (LLVM) backend**, compilation of cooperative vector operations + targets the available instruction set extensions (AVX512, NEON, etc.). + +Code snippets in the remainder of this section assume the following include +directives: + +.. code-block:: python + + import drjit as dr + import drjit.nn as nn + from drjit.auto.ad import Float16, TensorXf16 + +Motivation +---------- + +The cooperative vector API is available via the :py:mod:`drjit.nn` submodule. 
+Below is an example demonstrating how to use it to perform a matrix +multiplication. + +.. code-block:: python + + # Matrix shape + m, n = 3, 16 + + # Create a random matrix + offset + A = dr.normal(TensorXf, (m, n)) + b = dr.rand(TensorXf, m) + + # Pack 'A' and 'b' into a buffer with an optimal layout + buffer, A_view, b_view = nn.pack(A, b) + + # Create a cooperative vector + x = nn.CoopVec(... 16 values ...) + + # Evaluate A @ x + b + v_out = nn.matvec(A_view, v_in, b_view) + + # Unpack the resulting cooperative vector + x, y, z = v_out + +This involves the following steps: + +- Initializing matrix data and packing it into an optimized memory layout using + :py:func:`nn.pack() `. + +- Constructing a :py:class:`nn.CoopVec` containing the inputs to the matrix + multiplication.inputs. + +- Performing one or more matrix-vector multiplications and other arithmetic, + while keeping the state in cooperative vector form. + +- Unpacking the final cooperative vector into regular Dr.Jit arrays. + +Cooperative vectors +------------------- + +The central type of this API is the *cooperative vector* class +:py:class:`nn.CoopVec`. This is a dynamically sized vector with uniformly +typed elements. + +Unlike regular Dr.Jit arrays (e.g. :py:class:`drjit.cuda.ArrayXf`), cooperative +vectors *do not allow indexed element access*. For example, the following +operation raises an exception: + +.. code-block:: pycon + + >>> vec = nn.CoopVec(Float16(1), Float16(2)) + >>> vec[1] + Traceback (most recent call last): + File "", line 1, in + TypeError: 'drjit.nn.CoopVec' object is not subscriptable + +This restriction exists because the compiler may arbitrarily distribute +cooperative vector components across threads for efficiency. Allowing direct +indexing would interfere with this optimization. + +The :py:class:`drjit.nn.CoopVec` constructor accepts an arbitrary sequence +of :ref:`PyTrees ` containing Dr.Jit array and Python scalars and +flattens them into a cooperative vector: + +.. code-block:: python + + vec = nn.CoopVec( # Construct a 4D vector + Float16(1), + 3.0, + Array2f(4, 5) + ) + +Use the standard Python unpacking syntax to turn cooperative vectors back into +their components: + +.. code-block:: python + + x, y, z = vec # Unpack a cooperative 3D vector + x, y, *extra = vec # Unpack first 2 components, put rest into 'extra' + +The same syntax can also be used to concatenate vectors: + +.. code-block:: python + + vec_3 = nn.CoopVec(*vec_1, *vec_2) + +Cooperative vectors can also be converted into nested arrays, tensors, or +Python lists: + +.. code-block:: python + + vec_arr = Array3f(vec) + vec_ten = TensorXf(vec) + vec_lst = list(vec) + +Cooperative vectors are compatible with Dr.Jit's symbolic tracing +infrastructure and may be used as state variables in +:py:func:`drjit.while_loop` and :py:func:`drjit.if_stmt`. + +Arithmetic +^^^^^^^^^^ + +Cooperative vectors support a restricted set of arithmetic operations: + +- Elementary arithmetic operations: ``+``, ``-``, ``*`` (but no division) +- :py:func:`dr.fma() `, +- :py:func:`dr.minimum() `, :py:func:`dr.maximum() `, +- :py:func:`dr.log2() `, :py:func:`dr.exp2() `, +- :py:func:`dr.tanh() `, +- :py:func:`dr.step() `. +- :py:func:`nn.matvec() ` + +These operations directly map to hardware-optimized operations on CUDA/OptiX. +Operations outside of this set can be realized via unpacking/repacking, e.g.: + +.. code-block:: + + x : nn.CoopVec = ... + y = nn.CoopVec(dr.sin(v) for v in x) + +However, this may degrade performance. 
It is best to keep cooperative vectors +in their opaque layout whenever possible. + +Arithmetic operations may mix cooperative vectors and regular Dr.Jit arrays or +Python scalars, which will undergo implicit broadcasting. + +.. code-block:: + + x: nn.CoopVec[dr.cuda.Float16] = ... + y: dr.cuda.Float16 = ... + z = dr.maximum(x, 0) + y + +.. _matrix_views: + +Matrix views +------------ + +Input matrices and bias vectors should generally be converted into a +hardware-dependent layout to improve performance compared to the default +row-major representation (also, many operations raise exceptions on the +OptiX/CUDA backend when matrices are not in such an optimal layout). + +The function :py:func:`nn.pack() ` performs this conversion and +furthermore packs data into a shared buffer for optimal efficiency. The +function takes an arbitrary sequence of :ref:`PyTrees ` as input and +returns a result with the same structure. + +.. code-block:: python + + A: TensorXf = ... + b: Float = ... + A_view, b_view = nn.pack(A, b, layout='inference') + +Every Dr.Jit array or tensor will be replaced by a +:py:class:`drjit.nn.MatrixView`, which is a thin pointer into a shared buffer +annotated with layout and type metadata. The function can generate optimal +memory layouts for either *inference* (the default) and *training*. You must +specify ``layout='training'`` if you wish to differentiate matrix +multiplication in reverse mode. + +Following this step, ``A`` and ``b`` have been merged into ``buffer``, and +``A_view`` and ``b_view`` encode the offset and layout within this larger +buffer. Matrix views *cannot* be used in arithmetic expressions and are best +thought of as opaque handles. They only exist to describe the input of the +matrix-vector multiplication operation explained next. + +Two other view-related operations be useful in certain situations, please +see the linked documentation for details. + +- :py:func:`drjit.nn.unpack` converts optimal-layout data back into a row-major layout. +- :py:func:`drjit.nn.view` creates row-major views. + +Matrix-vector products +---------------------- + +The main purpose of cooperative vectors is the matrix-vector multiplication +operation :py:func:`nn.matvec() `: + +.. code-block:: python + + y = nn.matvec(A, x, b) # Compute y = A @ x + b + +Here, + +- ``A`` and ``b`` are *views* (:py:class:`nn.MatrixView`) created by + :py:func:`nn.pack() ` or :py:func:`nn.view() + `. +- ``x`` and ``y`` are cooperative vectors. They are interpreted as *column + vectors*, i.e., ``y = A[:, 0] * x[0] + A[:, 1] * x[1] + ... + b``. +- the ``b`` term is optional. + +The function also accepts an optional ``transpose=True`` parameter to compute +:math:`A^Tx + b`. + +The standard Python ``A @ x`` and ``A.T @ x`` matrix multiplication syntax +works as well. However, if your computation requires the addition of a ``b`` +vector, prefer :py:func:`nn.matvec() ` over this syntax, since +it merges both steps into a single operation. + +Differentiation +--------------- + +Cooperative vectors support automatic differentiation. Simply pack variables +with tracked gradients into cooperative vectors---the system will then +propagate derivatives through subsequent operations. Here is an example: + +.. code-block:: python + + # Differentiable input + a = Array2f16(..) 
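+    # (Assumption for this sketch: 'A' and 'b' below hold previously created
+    #  matrix and bias data, e.g. A = dr.normal(TensorXf16, (m, n)) and
+    #  b = dr.rand(TensorXf16, m), analogous to the earlier example.)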
+ dr.enable_grad(a) + + # Differentiable matrix + bias vector + buffer, A_view, b_view = nn.pack(A, b) + dr.enable_grad(buffer) + + # Pack grad-enabled variables into a cooperative vector + x = nn.CoopVec(a) + + # Differentiable matrix-vector multiplication + y = dr.matvec(A_view, x, b_view) + + r0, r1 = y # Unpack + loss = r0**2 + r1**2 # Continue calculation and .. + dr.backward_from(loss) # .. eventually backpropagate + +Specific views or cooperative vectors can also be detached via +:py:func:`drjit.detach()` to inhibit gradient propagation, e.g.: + +.. code-block:: python + + y = nn.matvec(A_view, dr.detach(x), dr.detach(b_view)) + +Note that the conversion functions :py:func:`nn.pack() ` and +:py:func:`nn.unpack() ` are *not differentiable*. This is +intentional: to train a neural network, convert the initial coefficient values +into training-optimal layout and optimize this representation directly. Doing +so is more efficient than changing layouts twice in every optimization step +(once for the weights and once for their derivatives). + +The following AD operations recognize :py:func:`nn.CoopVec +` and :py:func:`nn.MatrixView ` objects: + +- :py:func:`grad_enabled`, :py:func:`enable_grad`, :py:func:`disable_grad`. +- :py:func:`detach`. + +Performance considerations +-------------------------- + +- **CUDA/OptiX** backend: + + - :py:func:`nn.matvec() ` currently requires 16-bit + floating point arguments. FP8 formats may be added in the future. + + - Tensor cores work with 8x8 and 16x16 blocks. Matrices, whose row or column + counts are not a multiples of 8 or 16 will be zero-padded internally. There + is no performance benefit in working with such intermediate sizes. + +- **LLVM** backend: + + - There is no difference between row-major and training/inference-optimal + layouts on the CPU. However, using :py:func:`nn.pack() + ` is still recommended, since packing multiple arrays + into a shared buffer has a small performance benefit. + + - On Intel-compatible processors, using half precision cooperative vectors is + not recommended. FP16 matrix multiplication requires ``AVX512FP16``, an + extension not yet available on consumer CPUs as of 2025. Without this + extension, FP16 computation involves many costly FP16 ↔ FP32 roundtrips. diff --git a/docs/index.rst b/docs/index.rst index 0b5ac4fe..e10e6611 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -46,6 +46,8 @@ public API. bench cpp textures + coop_vec + nn faq .. toctree:: diff --git a/docs/misc.rst b/docs/misc.rst index 9321e6c9..07ccad5d 100644 --- a/docs/misc.rst +++ b/docs/misc.rst @@ -529,7 +529,7 @@ resolve at a later point. So here, we have - ``SelfCp``: a forward reference to ``drjit.llvm.ad._Array2fCp`` (more on this shortly), - ``ValT``: :py:class:`drjit.llvm.ad.Float`, - ``ValCpT``: a forward reference to ``drjit.llvm.ad._FloatCp`` (more on this shortly), -- ``RedT``: :py:class`drjit.llvm.ad.Float`, +- ``RedT``: :py:class:`drjit.llvm.ad.Float`, - ``PlainT``: :py:class:`drjit.llvm.ad.Array2f`, and - ``MaskT``: :py:class:`drjit.llvm.ad.Array2b`. diff --git a/docs/nn.rst b/docs/nn.rst new file mode 100644 index 00000000..8e90b02c --- /dev/null +++ b/docs/nn.rst @@ -0,0 +1,136 @@ +.. py:currentmodule:: drjit.nn + +.. _neural_nets: + +Neural Networks +=============== + +Dr.Jit's neural network infrastructure builds on :ref:`cooperative vectors +`. Please review their documentation before reading this section. 
+ +The module :py:mod:`drjit.nn` provides convenient modular abstractions to +construct, evaluate, and optimize neural networks. Their design resembles the +PyTorch `nn.Module +`__ classes. +The Dr.Jit :py:class:`nn.Module ` class takes a cooperative vector as input +and produces another cooperative vector. Modules can be chained to form longer +sequential pipelines. + +.. warning:: + + The neural network classes are experimental and subject to change in future + versions of Dr.Jit. + +List +---- + +The set of neural network module currently includes: + +- Sequential evaluation of a list of models: :py:class:`nn.Sequential `. + +- Linear and affine layers: :py:class:`nn.Linear `. + +- Encoding layers: :py:class:`nn.SinEncode `, :py:class:`nn.TriEncode `. + +- Activation functions and other nonlinear transformations: :py:class:`nn.ReLU `, :py:class:`nn.LeakyReLU `, + :py:class:`nn.Exp `, :py:class:`nn.Exp2 `, :py:class:`nn.Tanh `. + +- Miscellaneous: :py:class:`nn.Cast `, :py:class:`nn.ScaleAdd `. + +Example +------- + +Below is a fully worked out example demonstrating how to use it to declare and +optimize a small `multilayer perceptron +`__ (MLP). This network +implements a 2D neural field (right) that we then fit to a low-resolution image of `The +Great Wave off Kanagawa +`__ (left). + +.. image:: https://rgl.s3.eu-central-1.amazonaws.com/media/uploads/wjakob/2024/06/coopvec-screenshot.png + :width: 600 + :align: center + +The optimization uses the *Adam* optimizer (:py:class:`dr.opt.Adam +`) optimizer and a *gradient scaler* +(:py:class:`dr.opt.GradScaler `) for adaptive +mixed-precision training. + +.. code-block:: python + + from tqdm.auto import tqdm + import imageio.v3 as iio + import drjit as dr + import drjit.nn as nn + from drjit.opt import Adam, GradScaler + from drjit.auto.ad import Texture2f, TensorXf, TensorXf16, Float16, Float32, Array2f, Array3f + + # Load a test image and construct a texture object + ref = TensorXf(iio.imread("https://rgl.s3.eu-central-1.amazonaws.com/media/uploads/wjakob/2024/06/wave-128.png") / 256) + tex = Texture2f(ref) + + # Ensure consistent results when re-running the following + dr.seed(0) + + # Establish the network structure + net = nn.Sequential( + nn.TriEncode(16, 0.2), + nn.Cast(Float16), + nn.Linear(-1, -1, bias=False), + nn.LeakyReLU(), + nn.Linear(-1, -1, bias=False), + nn.LeakyReLU(), + nn.Linear(-1, -1, bias=False), + nn.LeakyReLU(), + nn.Linear(-1, 3, bias=False), + nn.Exp() + ) + + # Instantiate the network for a specific backend + input size + net = net.alloc(TensorXf16, 2) + + # Convert to training-optimal layout + weights, net = nn.pack(net, layout='training') + print(net) + + # Optimize a single-precision copy of the parameters + opt = Adam(lr=1e-3, params={'weights': Float32(weights)}) + + # This is an adaptive mixed-precision (AMP) optimization, where a half + # precision computation runs within a larger single-precision program. + # Gradient scaling is required to make this numerically well-behaved. 
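+    # (scaler.scale() multiplies the loss prior to backpropagation; the later
+    #  scaler.step() call removes this scale factor from the gradients again
+    #  before updating 'opt'.)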
+ scaler = GradScaler() + + res = 256 + + for i in tqdm(range(40000)): + # Update network state from optimizer + weights[:] = Float16(opt['weights']) + + # Generate jittered positions on [0, 1]^2 + t = dr.arange(Float32, res) + p = (Array2f(dr.meshgrid(t, t)) + dr.rand(Array2f, (2, res * res))) / res + + # Evaluate neural net + L2 loss + img = Array3f(net(nn.CoopVec(p))) + loss = dr.squared_norm(tex.eval(p) - img) + + # Mixed-precision training: take suitably scaled steps + dr.backward(scaler.scale(loss)) + scaler.step(opt) + + # Done optimizing, now let's plot the result + t = dr.linspace(Float32, 0, 1, res) + p = Array2f(dr.meshgrid(t, t)) + img = Array3f(net(nn.CoopVec(p))) + + # Convert 'img' with shape 3 x (N*N) into a N x N x 3 tensor + img = dr.reshape(TensorXf(img, flip_axes=True), (res, res, 3)) + + import matplotlib.pyplot as plt + fig, ax = plt.subplots(1, 2, figsize=(10,5)) + ax[0].imshow(ref) + ax[1].imshow(dr.clip(img, 0, 1)) + fig.tight_layout() + plt.show() + diff --git a/docs/reference.rst b/docs/reference.rst index 6d369a13..a71737bd 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -146,6 +146,8 @@ Miscellaneous operations .. autofunction:: binary_search .. autofunction:: make_opaque .. autofunction:: copy +.. autofunction:: linear_to_srgb +.. autofunction:: srgb_to_linear Just-in-time compilation ------------------------ @@ -276,6 +278,7 @@ Standard mathematical functions .. autofunction:: sign .. autofunction:: copysign .. autofunction:: mulsign +.. autofunction:: step Operations for vectors and matrices ----------------------------------- @@ -631,6 +634,8 @@ Low-level bits .. py:currentmodule:: drjit.detail .. autofunction:: set_leak_warnings .. autofunction:: leak_warnings +.. autofunction:: llvm_version +.. autofunction:: cuda_version .. py:currentmodule:: drjit Typing @@ -690,7 +695,6 @@ gradient-based optimization and adaptive mixed-precision training. .. automethod:: __delitem__ .. automethod:: __contains__ .. automethod:: __len__ - .. automethod:: update .. automethod:: keys .. automethod:: values .. automethod:: items @@ -713,3 +717,121 @@ gradient-based optimization and adaptive mixed-precision training. .. automethod:: step .. automethod:: scale .. automethod:: unscale + +.. _coop_vec_ref: + +Cooperative Vectors +------------------- + +.. py:module:: drjit.nn + +The :py:mod:`drjit.nn` module provides infrastructure to implement small +neural networks and revolves around the notion of *cooperative vectors* that +facilitate code generation of matrix-vector products. Please see the separate +:ref:`documentation section ` for an introduction. + +.. autoclass:: CoopVec + + .. automethod:: __init__ + .. automethod:: __add__ + .. automethod:: __sub__ + .. automethod:: __mul__ + .. automethod:: __len__ + .. automethod:: __repr__ + + .. property:: index + :type: int + + Stores the Dr.Jit variable index of the cooperative vector. + + .. property:: type + :type: type[drjit.ArrayBase] + + Stores the element type + +.. autoclass:: MatrixView + + .. automethod:: __getitem__ + + .. property:: dtype + :type: drjit.VarType + + Scalar type underlying the view. + + .. property:: shape + :type: tuple[int, int] + + Number of rows/columns. Vectors are stored as matrices with one column. + + .. property:: layout + :type: MatrixLayout + + One of several possible matrix layouts (training/inference-optimal and + row-major). + + .. property:: stride + :type: int + + Row stride (in # of elements) + + .. property:: size + :type: int + + Total number of elements + + .. 
property:: transpose + :type: bool + + The ``MatrixView.T`` property flips this flag (all other + values stay unchanged). + + .. property:: buffer + :type: drjit.ArrayBase + + The underlying buffer, which may contain additional matrices/vectors + besides the data referenced by the :py:class:`MatrixView`. + + .. property:: T + :type: MatrixView + + Return a transposed view. + + .. property:: grad + :type: MatrixView + + Return an analogous view of the gradient. + +.. autofunction:: view +.. autofunction:: pack +.. autofunction:: unpack +.. autofunction:: matvec +.. autofunction:: cast + +Neural Networks +--------------- + +Besides :ref:`cooperative vector classes `, the +:py:mod:`drjit.nn` module also provides convenient abstractions to declare, +evaluate, and train networks. Please see the separate :ref:`documentation +section ` for an introduction. + +.. autoclass:: Model + + .. automethod:: __call__ + .. automethod:: alloc + +.. autoclass:: Sequential + + .. automethod:: __len__ + .. automethod:: __getitem__ + +.. autoclass:: Linear +.. autoclass:: ReLU +.. autoclass:: LeakyReLU +.. autoclass:: SinEncode +.. autoclass:: TriEncode +.. autoclass:: Exp +.. autoclass:: Exp2 +.. autoclass:: Tanh +.. autoclass:: Cast +.. autoclass:: ScaleAdd diff --git a/docs/what.rst b/docs/what.rst index e76af0f1..04726ea0 100644 --- a/docs/what.rst +++ b/docs/what.rst @@ -22,9 +22,11 @@ Using Dr.Jit involves two steps: **That's it**. It doesn't do much, but it does this *very efficiently*. Perhaps the most significant difference to the majority of existing tools is -that Dr.Jit is *not* a machine learning library. Its sweet spot are non-neural -programs characterized by *embarrassing parallelism*---that is to say, programs -with large data-parallel regions. A good example of this are `Monte Carlo +that Dr.Jit is *not primarily* a machine learning library. While it does +provide support for neural network :ref:`evaluation and training `, +its sweet spot are non-neural programs characterized by *embarrassing +parallelism*---that is to say, programs with large data-parallel regions. A +good example of this are `Monte Carlo `__ methods with their parallel sample evaluation (indeed, the reason why this project was originally created was to provide the foundation of `Mitsuba 3 diff --git a/drjit/__init__.py b/drjit/__init__.py index f0bab723..0b96ef7d 100644 --- a/drjit/__init__.py +++ b/drjit/__init__.py @@ -2269,7 +2269,7 @@ def upsample(t, shape=None, scale_factor=None): _rand_seed : int = 0 -def seed(value: int): +def seed(value: int) -> None: """ Reset the seed value that is used for pseudorandom number generation. @@ -2647,6 +2647,44 @@ def assert_equal( **kwargs, ) +def srgb_to_linear(x: ArrayT, clip_range: bool = True) -> ArrayT: + """ + Convert a sRGB gamma-corrected intensity value on the interval [0, 1] into + a linear intensity value on the interval [0, 1]. + + Values outside of the range [0, 1] are clipped by default. You may specify + `clip_range=False` to avoid this step if your data is already guranteed to be in + this range. + """ + + if clip_range: + x = clip(x, 0, 1) + + return select( + x < 0.04045, + x / 12.92, + fma(x, 1 / 1.055, 0.055 / 1.055) ** 2.4 + ) + +def linear_to_srgb(x: ArrayT, clip_range: bool = True) -> ArrayT: + """ + Convert a linear intensity value on the interval [0, 1] to into a sRGB + value by applying the underlying gamma correction curve. + + Values outside of the range [0, 1] are clipped by default. 
You may specify + `clip_range=False` to avoid this step if your data is already guranteed to be in + this range. + """ + + if clip_range: + x = clip(x, 0, 1) + + return select( + x < 0.0031308, + x * 12.92, + fma(1.055, x ** (1.0 / 2.4), -0.055) + ) + newaxis = None diff --git a/drjit/nn.py b/drjit/nn.py new file mode 100644 index 00000000..081b5591 --- /dev/null +++ b/drjit/nn.py @@ -0,0 +1,505 @@ +from __future__ import annotations +import drjit +import sys + +if sys.version_info < (3, 11): + from typing_extensions import Tuple, Sequence, Union, Type, TypeAlias, Optional, Any +else: + from typing import Tuple, Sequence, Union, Type, TypeAlias, Optional, Any + +# Import classes/functions from C++ extension +MatrixView = drjit.detail.nn.MatrixView +CoopVec = drjit.detail.nn.CoopVec +pack = drjit.detail.nn.pack +unpack = drjit.detail.nn.unpack +matvec = drjit.detail.nn.matvec +view = drjit.detail.nn.view +cast = drjit.detail.nn.cast +T = drjit.detail.nn.T + +TensorOrViewOrNone: TypeAlias = Union[ + drjit.ArrayBase, + MatrixView, + None +] + +class Module: + """ + This is the base class of a modular set of operations that make + the specification of neural network architectures more convenient. + + Module subclasses are :ref:`PyTrees `, which means that various + Dr.Jit operations can automatically traverse them. + + Constructing a neural network generally involves the following pattern: + + .. code-block:: + + # 1. Establish the network structure + net = nn.Sequential( + nn.Linear(-1, 32, bias=False), + nn.ReLU(), + nn.Linear(-1, 3) + ) + + # 2. Instantiate the network for a specific backend + input size + net = net.alloc(TensorXf16, 2) + + # 3. Pack coefficients into a training-optimal layout + coeffs, net = nn.pack(net, layout='training') + + Network evaluation expects a :ref:`cooperative vector ` as input + (i.e., ``net(nn.CoopVec(...))``) and returns another cooperative vector. + The ``coeffs`` buffer contains all weight/bias data in training-optimal + format and can be optimized, which will directly impact the packed network + ``net`` that references this buffer. + """ + def __call__(self, arg: CoopVec, /) -> CoopVec: + """ + Evaluate the model with an input cooperative vector and return the result. + """ + raise NotImplementedError(f"{type(self).__name__}.__call__() implementation is missing.") + + def _alloc(self, dtype: Type[drjit.ArrayBase], size: int, /) -> Tuple[Module, int]: + """ + Internal method used to propagate argument sizes and allocate weight + storage of all NN modules. + + The method takes to parameters as input: a weight storage type + ``dtype`` (e.g., :py:class:`drjit.cuda.ad.TensorXf16`) and ``size``, + the number of input arguments of the module. The function returns a + potentially new module instance with allocated weights, plus the number + of outputs. + """ + return self, size + + def alloc(self, dtype: Type[drjit.ArrayBase], size: int = -1) -> Module: + """ + Returns a new instance of the model with allocated weights. + + This function expects a suitable tensor ``dtype`` (e.g. + :py:class:`drjit.cuda.ad.TensorXf16` or + :py:class:`drjit.llvm.ad.TensorXf`) that will be used to store the + weights on the device. + + If the model or one of its sub-models is automatically sized (e.g., + ``input_features=-1`` in :py:class:`drjit.nn.Linear`), the final + network configuration may ambiguous and an exception will be raised. + Specify the optional ``size`` parameter in such cases to inform the + allocation about the size of the input cooperative vector. 
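+
+        A minimal usage sketch (mirroring the pattern from the
+        :py:class:`drjit.nn.Module` documentation; the tensor type shown is
+        just one possible choice):
+
+        .. code-block:: python
+
+            net = nn.Sequential(
+                nn.Linear(-1, 32, bias=False),
+                nn.ReLU(),
+                nn.Linear(-1, 3)
+            )
+
+            # Instantiate for a CUDA FP16 tensor type and a 2D input
+            net = net.alloc(drjit.cuda.ad.TensorXf16, size=2)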
+ """ + return self._alloc(dtype, size)[0] + + def __repr__(self) -> str: + return f"{type(self).__name__}()" + +class Sequential(Module, Sequence[Module]): + """ + This model evaluates provided arguments ``arg[0]``, ``arg[1]``, ..., in sequence. + """ + DRJIT_STRUCT = { 'layers' : tuple } + + layers: tuple[Module, ...] + + def __init__(self, *args: Module): + self.layers = args + + def __call__(self, arg: CoopVec, /) -> CoopVec: + for l in self.layers: + arg = l(arg) + return arg + + def _alloc(self, dtype: Type[drjit.ArrayBase], size: int = -1, /) -> Tuple[Module, int]: + result = [] + for l in self.layers: + l_new, size = l._alloc(dtype, size) + result.append(l_new) + return Sequential(*result), size + + def __len__(self): + """Return the number of contained models""" + return len(self.layers) + + def __getitem__(self, index: int, /) -> Module: # type: ignore + """Return the model at position ``index``""" + return self.layers[index] + + def __repr__(self) -> str: + s = 'Sequential(\n' + n = len(self.layers) + for i in range(n): + s += ' ' + repr(self.layers[i]).replace('\n', '\n ') + if i + 1 < n: + s += ',' + s += '\n' + s += ')' + return s + +class ReLU(Module): + r""" + ReLU (rectified linear unit) activation function. + + This model evaluates the following expression: + + .. math:: + + \mathrm{ReLU}(x) = \mathrm{max}\{x, 0\}. + + """ + + DRJIT_STRUCT = { } + def __call__(self, arg: CoopVec, /) -> CoopVec: + return drjit.maximum(arg, 0) + +class LeakyReLU(Module): + r""" + "Leaky" ReLU (rectified linear unit) activation function. + + This model evaluates the following expression: + + .. math:: + + \mathrm{LeakyReLU}(x) = \begin{cases} + x,&\mathrm{if}\ x\ge 0,\\ + \texttt{negative\_slope}\cdot x,&\mathrm{otherwise}. + \end{cases} + """ + + DRJIT_STRUCT = { 'negative_slope': Union[float, drjit.ArrayBase] } + def __init__(self, negative_slope: Union[float, drjit.ArrayBase] = 1e-2): + self.negative_slope = negative_slope + + def __call__(self, arg: CoopVec, /) -> CoopVec: + return drjit.maximum(arg, 0) + drjit.minimum(arg, 0.0) * self.negative_slope + + +class Exp2(Module): + r""" + Applies the base-2 exponential function to each component. + + .. math:: + + \mathrm{Exp2}(x) = 2^x + + On the CUDA backend, this function directly maps to an efficient native GPU instruction. + """ + DRJIT_STRUCT = { } + def __call__(self, arg: CoopVec, /) -> CoopVec: + return drjit.exp2(arg) + +class Exp(Module): + r""" + Applies the exponential function to each component. + + .. math:: + + \mathrm{Exp}(x) = e^x + """ + DRJIT_STRUCT = { } + def __call__(self, arg: CoopVec, /) -> CoopVec: + return drjit.exp2(arg * (1 / drjit.log(2))) + +class Tanh(Module): + r""" + Applies the hyperbolic tangent function to each component. + + .. math:: + + \mathrm{Tanh}(x) = \frac{\exp(x)-\exp(-x)}{\exp(x)+\exp(-x)} + + On the CUDA backend, this function directly maps to an efficient native GPU instruction. + """ + DRJIT_STRUCT = { } + def __call__(self, arg: CoopVec, /) -> CoopVec: + return drjit.tanh(arg) + +class ScaleAdd(Module): + r""" + Scale the input by a fixed scale and apply an offset. + + Note that ``scale`` and ``offset`` are assumed to be constant (i.e., not trainable). + + .. 
math:: + + \mathrm{ScaleAdd}(x) = x\cdot\texttt{scale} + \texttt{offset} + """ + DRJIT_STRUCT = {'scale': Union[None, float, int, drjit.ArrayBase], + 'offset': Union[None, float, int, drjit.ArrayBase]} + def __init__(self, scale: Union[float, int, drjit.ArrayBase, None] = None, + offset: Union[float, int, drjit.ArrayBase, None] = None): + self.scale = scale + self.offset = offset + def __call__(self, arg: CoopVec, /) -> CoopVec: + if not self.scale or not self.offset: + raise Exception("drjit.nn.ScaleAdd(): you must set a scale and offset!") + return drjit.fma(arg, self.scale, self.offset) + +class Cast(Module): + """ + Cast the input cooperative vector to a different precision. Should be + instantiated with the desired element type, e.g. ``Cast(drjit.cuda.ad.Float32)`` + """ + DRJIT_STRUCT = { 'dtype': Optional[Type[drjit.ArrayBase]] } + def __init__(self, dtype: Optional[Type[drjit.ArrayBase]] = None): + self.dtype = dtype + def __call__(self, arg: CoopVec, /) -> CoopVec: + return cast(arg, self.dtype) + def __repr__(self): + return f'Cast(dtype={self.dtype.__name__})' + +class Linear(Module): + r""" + This layer represents a learnable affine linear transformation of the input + data following the expression :math:`\mathbf{y} = \mathbf{A}\mathbf{x} + + \mathbf{b}`. + + It takes ``in_features`` inputs and returns a cooperative vector with + ``out_features`` dimensions. The following parameter values have a special + a meaning: + + - ``in_features=-1``: set the input size to match the previous model's + output (or the input of the network, if there is no previous model). + + - ``out_features=-1``: set the output size to match the input size. + + The bias (:math:`\textbf{b}`) term is optional and can be disabled by + specifying ``bias=False``. + + The method :py:func:`Module.alloc` initializes the underlying coefficient + storage with random weights following a uniform Xavier initialization, + i.e., uniform variates on the interval :math:`[-k,k]` where + :math:`k=1/\sqrt{\texttt{out\_features}}`. Call :py:func:`drjit.seed()` prior + to this step to ensure that weights are always initialized with the same + values, which can be helpful for hyperpararameter tuning and + reproducibility. + """ + config: Tuple[int, int, bool] + weights: TensorOrViewOrNone + bias: TensorOrViewOrNone + + DRJIT_STRUCT = { + 'config': Tuple[int, int, bool], + 'weights': TensorOrViewOrNone, + 'bias': TensorOrViewOrNone + } + + def __init__(self, in_features: int = -1, out_features: int = -1, bias = True) -> None: + self.config = (in_features, out_features, bias) + self.weights = self.bias = None + + def __repr__(self) -> str: + s = f'Linear(in_features={self.config[0]}, out_features={self.config[1]}' + if not self.config[2]: + s += ', bias=False' + s += ')' + return s + + def __call__(self, arg: CoopVec, /) -> CoopVec: + if self.weights is None: + raise RuntimeError( + "Uninitialized network. Call 'net = net.alloc(""" + ")' to initialize the weight storage first. Following this, " + "use 'drjit.nn.pack()' to transform the network into an " + "optimal layout for evaluation." + ) + elif not isinstance(self.weights, MatrixView) or \ + (self.bias is not None and not isinstance(self.bias, MatrixView)): + raise RuntimeError( + "Uninitialized network. Use 'drjit.nn.pack()' to transform" + "the network into an optimal layout for evaluation." 
+ ) + return matvec(self.weights, arg, self.bias) + + def _alloc(self, dtype: Type[drjit.ArrayBase], size : int = -1, /) -> Tuple[Module, int]: + in_features, out_features, bias = self.config + if in_features < 0: + in_features = size + if out_features < 0: + out_features = in_features + if in_features == -1 or out_features == -1: + raise RuntimeError("The network contains layers with an unspecified " + "size. You must specify the input size to drjit.nn.Module.alloc().") + + result = Linear(in_features, out_features, bias) + # Xavier (uniform) initialization, matches PyTorch + scale = drjit.sqrt(1 / out_features) + Float32 = drjit.float32_array_t(dtype) + samples = drjit.rand(Float32, (out_features, in_features)) + result.weights = dtype(drjit.fma(samples, 2, -1) * scale) + if bias: + result.bias = drjit.zeros(dtype, out_features) + return result, out_features + +def _sincos_tri(t: T) -> tuple[T, T]: + """Implementation detail of the TriEncode class""" + s = t - .25 + st = s - drjit.round(s) + ct = t - drjit.round(t) + return ( + drjit.fma(drjit.abs(st), -4, 1), + drjit.fma(drjit.abs(ct), -4, 1) + ) + +class TriEncode(Module): + r""" + Map an input onto a higher-dimensional space by transforming it using + triangular sine and cosine approximations of an increasing frequency. + + .. math:: + + x\mapsto \begin{bmatrix} + \sin_\triangle(2^0\,x)\\ + \cos_\triangle(2^0\,x)\\ + \vdots\\ + \cos_\triangle(2^{n-1}\, x)\\ + \sin_\triangle(2^{n-1}\, x) + \end{bmatrix} + + where + + .. math:: + + \cos_\triangle(x) = 1-4\left|x-\mathrm{round}(x)\right| + + and + + .. math:: + + \sin_\triangle(x) = \cos_\triangle(x-1/4) + + The value :math:`n` refers to the number of *octaves*. This layer increases + the dimension by a factor of :math:`2n`. + + Note that this encoding has period 1. If your input exceeds the interval + :math:`[0, 1]`, it is advisable that you reduce it to this range to avoid + losing information. + + Minima/maxima of higher frequency components conincide on a regular + lattice, which can lead to reduced fitting performance at those locations. + Specify the optional parameter ``shift`` to phase-shift the :math:`i`-th + frequency by :math:`2\,\pi\,\mathrm{shift}` to avoid this behavior. + + The following plot shows the first two octaves applied to the linear + function on :math:`[0, 1]` (without shift). + + .. image:: https://rgl.s3.eu-central-1.amazonaws.com/media/uploads/wjakob/2024/06/tri_encode_light.svg + :class: only-light + :width: 600px + :align: center + + .. 
image:: https://rgl.s3.eu-central-1.amazonaws.com/media/uploads/wjakob/2024/06/tri_encode_dark.svg + :class: only-dark + :width: 600px + :align: center + """ + + DRJIT_STRUCT = { 'octaves' : int, 'shift': float, 'channels': int } + + def __init__(self, octaves: int = 0, shift: float = 0) -> None: + self.octaves = octaves + self.shift = shift + self.channels = -1 + + def _alloc(self, dtype: Type[drjit.ArrayBase], size : int = -1, /) -> Tuple[Module, int]: + r = TriEncode(self.octaves, self.shift) + r.channels = size + return r, size * self.octaves * 2 + + def __repr__(self) -> str: + return f'TriEncode(octaves={self.octaves}, shift={self.shift}, in_channels={self.channels}, out_features={self.channels*self.octaves*2})' + + def __call__(self, arg: CoopVec, /) -> CoopVec: + args, r = list(arg), list() + for arg in args: + for i in range(self.octaves): + s, c = _sincos_tri(drjit.fma(arg, 2**i, self.shift*i)) + r.append(s) + r.append(c) + return CoopVec(r) + + +class SinEncode(Module): + r""" + Map an input onto a higher-dimensional space by transforming it using sines + and cosines of an increasing frequency. + + .. math:: + + x\mapsto \begin{bmatrix} + \sin(2^0\, 2\pi x)\\ + \cos(2^0\, 2\pi x)\\ + \vdots\\ + \sin(2^{n-1}\, 2\pi x)\\ + \cos(2^{n-1}\, 2\pi x)\\ + \end{bmatrix} + + + The value :math:`n` refers to the number of *octaves*. This layer increases + the dimension by a factor of :math:`2n`. + + Note that this encoding has period 1. If your input exceeds the interval + :math:`[0, 1]`, it is advisable that you reduce it to this range to avoid + losing information. + + Minima/maxima of higher frequency components conincide on a regular + lattice, which can lead to reduced fitting performance at those locations. + Specify the optional parameter ``shift`` to phase-shift the :math:`i`-th + frequency by :math:`\mathrm{shift}` radians to avoid this behavior. + + The following plot shows the first two octaves applied to the linear + function on :math:`[0, 1]` (without shift). + + .. image:: https://rgl.s3.eu-central-1.amazonaws.com/media/uploads/wjakob/2024/06/sin_encode_light.svg + :class: only-light + :width: 600px + :align: center + + .. 
image:: https://rgl.s3.eu-central-1.amazonaws.com/media/uploads/wjakob/2024/06/sin_encode_dark.svg + :class: only-dark + :width: 600px + :align: center + """ + + DRJIT_STRUCT = { 'octaves' : int, 'shift': Union[tuple, None], 'channels': int } + + def __init__(self, octaves: int = 0, shift: float = 0) -> None: + self.octaves = octaves + self.channels = -1 + + if shift == 0: + self.shift = None + else: + self.shift = (drjit.sin(shift * 2 * drjit.pi), + drjit.cos(shift * 2 * drjit.pi)) + + def _alloc(self, dtype: Type[drjit.ArrayBase], size : int = -1, /) -> Tuple[Module, int]: + r = SinEncode(self.octaves) + r.channels = size + r.shift = self.shift + return r, size * self.octaves * 2 + + def __repr__(self) -> str: + return f'SinEncode(octaves={self.octaves}, shift={self.shift}, in_channels={self.channels}, out_features={self.channels*self.octaves*2})' + + def __call__(self, arg: CoopVec, /) -> CoopVec: + args, r = list(arg), list() + for arg in args: + s, c = drjit.sincos(arg * 2 * drjit.pi) + r.append(s) + r.append(c) + for _ in range(1, self.octaves): + # Recurrence for double angle sine/cosine + s2 = 2 * s + s, c = s2 * c, drjit.fma(-s2, s, 1) + r.append(s) + r.append(c) + + if self.shift: + # Recurrence for sine/cosine angle addition + ss, cs = self.shift + s, c = drjit.fma(s, cs, c*ss), \ + drjit.fma(c, cs, -s*ss) + + return CoopVec(r) + + diff --git a/drjit/stubs.pat b/drjit/stubs.pat index a38804d3..437bc732 100644 --- a/drjit/stubs.pat +++ b/drjit/stubs.pat @@ -108,13 +108,30 @@ drjit.select$: @overload def select(arg0: bool | AnyArray, arg1: T, arg2: T) -> T: ... -drjit.(atan2|minimum|maximum)$: +drjit.atan2$: @overload def \1(arg0: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], arg1: SelfCpT, /) -> SelfT: \doc @overload def \1(arg0: SelfCpT, arg1: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], /) -> SelfT: ... + +drjit.step$: + @overload + def \1(arg0: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], arg1: SelfCpT, /) -> SelfT: + \doc + @overload + def \1(arg0: SelfCpT, arg1: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], /) -> SelfT: ... + +drjit.(minimum|maximum)$: + @overload + def \1(arg0: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], arg1: SelfCpT, /) -> SelfT: + \doc + @overload + def \1(arg0: SelfCpT, arg1: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], /) -> SelfT: ... + @overload + def \1(arg0: CoopVec[ArrayT], arg1: object) -> CoopVec[ArrayT]: ... @overload + def \1(arg0: object, arg1: CoopVec[ArrayT]) -> CoopVec[ArrayT]: ... def \1(arg0: T, arg1: T, /) -> T: ... drjit.(empty|zeros|ones)$: @@ -128,21 +145,41 @@ drjit.(full|opaque)$: @overload def \1(dtype: type[T], value: T, shape: int | Sequence[int]) -> T: ... -drjit.(fma|lerp)$: +drjit.lerp$: @overload - def \1(arg0: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], arg1: SelfCpT, arg2: SelfCpT, /) -> SelfT: + def lerp(arg0: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], arg1: SelfCpT, arg2: SelfCpT, /) -> SelfT: \doc @overload - def \1(arg0: SelfCpT, arg1: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], arg2: SelfCpT, /) -> SelfT: ... + def lerp(arg0: SelfCpT, arg1: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], arg2: SelfCpT, /) -> SelfT: ... @overload - def \1(arg0: SelfCpT, arg1: SelfCpT, arg2: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], /) -> SelfT: ... 
+ def lerp(arg0: SelfCpT, arg1: SelfCpT, arg2: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], /) -> SelfT: ... @overload - def \1(arg0: T, arg1: T, arg2: T) -> T: ... + def lerp(arg0: T, arg1: T, arg2: T) -> T: ... + +drjit.fma$: + @overload + def fma(arg0: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], arg1: SelfCpT, arg2: SelfCpT, /) -> SelfT: + \doc + @overload + def fma(arg0: SelfCpT, arg1: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], arg2: SelfCpT, /) -> SelfT: ... + @overload + def fma(arg0: SelfCpT, arg1: SelfCpT, arg2: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], /) -> SelfT: ... + @overload + def fma(arg0: CoopVec[ArrayT], arg1: object, arg2: object) -> CoopVec[ArrayT]: ... + @overload + def fma(arg0: object, arg1: CoopVec[ArrayT], arg2: object) -> CoopVec[ArrayT]: ... + @overload + def fma(arg0: object, arg1: object, arg2: CoopVec[ArrayT]) -> CoopVec[ArrayT]: ... + @overload + def fma(arg0: T, arg1: T, arg2: T) -> T: ... drjit.reshape$: \from typing import Literal + @overload def reshape(dtype: type[T], value: object, shape: int | Sequence[int], order: Literal['A', 'C', 'F'] = 'A', shrink: bool = False) -> T: \doc + @overload + def reshape(value: object, shape: int | Sequence[int], order: Literal['A', 'C', 'F'] = 'A', shrink: bool = False) -> T: ... drjit.(isnan|isinf|isfinite)$: @overload @@ -265,7 +302,6 @@ drjit.sh_eval$: def sh_eval(d: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], order: int) -> list[ValT]: \doc - # -------------- drjit.syntax, interop, detail ---------------- # Clean the drjit.interop stub @@ -645,3 +681,6 @@ drjit.__prefix__: \from typing import TypeAlias \from collections.abc import Iterable, Sequence Axis: TypeAlias = int | tuple[int] | None + +drjit.coop.__prefix__: + \from typing import overload, Literal diff --git a/ext/drjit-core b/ext/drjit-core index 32486b64..89c0db27 160000 --- a/ext/drjit-core +++ b/ext/drjit-core @@ -1 +1 @@ -Subproject commit 32486b64dcbf3a8f7d0a28c274863d2c8ea25f65 +Subproject commit 89c0db27ea5a1d4b49e310c058386098933bc5bd diff --git a/include/drjit/extra.h b/include/drjit/extra.h index d2f24dd0..f10dd932 100644 --- a/include/drjit/extra.h +++ b/include/drjit/extra.h @@ -488,6 +488,9 @@ extern DRJIT_EXTRA_EXPORT uint64_t ad_var_map_get(uint64_t index); extern DRJIT_EXTRA_EXPORT int ad_leak_warnings(); extern DRJIT_EXTRA_EXPORT void ad_set_leak_warnings(int value); +/// Extract the i-th predecessor of an AD node (or return 0) +extern DRJIT_EXTRA_EXPORT uint32_t ad_pred(uint32_t index, uint32_t i); + #if defined(__GNUC__) DRJIT_INLINE uint64_t ad_var_inc_ref(uint64_t index) JIT_NOEXCEPT { /* If 'index' is known at compile time, it can only be zero, in @@ -520,6 +523,34 @@ DRJIT_INLINE void ad_var_dec_ref(uint64_t index) JIT_NOEXCEPT { // Return the AD reference count of a variable (for debugging) extern DRJIT_EXTRA_EXPORT uint32_t ad_var_ref(uint64_t index); +/// --------------------- Cooperative vector API --------------------- + +/// Pack a set of regular Dr.Jit variables to form a cooperative vector +extern DRJIT_EXTRA_EXPORT uint64_t ad_coop_vec_pack(uint32_t n, const uint64_t *in); + +/// Unpack a cooperative vector into its components +extern DRJIT_EXTRA_EXPORT void ad_coop_vec_unpack(uint64_t index, uint32_t n, uint64_t *out); + +/// Perform a unary operation on a cooperative vector +extern DRJIT_EXTRA_EXPORT uint64_t ad_coop_vec_unary_op(JitOp op, uint64_t a0); + +/// Perform a binary operation on a pair of cooperative vectors +extern 
DRJIT_EXTRA_EXPORT uint64_t ad_coop_vec_binary_op(JitOp op, uint64_t a0, uint64_t a1); + +/// Perform a ternary operation on a triplet of cooperative vectors +extern DRJIT_EXTRA_EXPORT uint64_t ad_coop_vec_ternary_op(JitOp op, uint64_t a0, uint64_t a1, uint64_t a2); + +/// Perform a matrix-vector multiplication + bias addition +extern DRJIT_EXTRA_EXPORT uint64_t ad_coop_vec_matvec(uint64_t A_index, + const MatrixDescr *A_descr, + uint64_t x_index, + uint64_t b_index, + const MatrixDescr *b_descr, + int transpose); + +/// Cast a cooperative vector to a different precision +extern DRJIT_EXTRA_EXPORT uint64_t ad_coop_vec_cast(uint64_t index, VarType vt); + #if defined(__cplusplus) } #endif diff --git a/src/extra/autodiff.cpp b/src/extra/autodiff.cpp index 11c5fda2..6e45b62b 100644 --- a/src/extra/autodiff.cpp +++ b/src/extra/autodiff.cpp @@ -43,6 +43,7 @@ */ #include "common.h" +#include "drjit-core/jit.h" #include #include #include @@ -140,6 +141,26 @@ DRJIT_NOINLINE JitVar scalar(JitBackend backend, VarType type, double value) { } } +/// As above, but for cooperative vectors +DRJIT_NOINLINE JitVar scalar_coop_vec(JitBackend backend, VarType type, double value, uint32_t length) { + switch (type) { + case VarType::Float16: { + drjit::half v = (drjit::half) value; + return JitVar::steal(jit_coop_vec_literal(backend, VarType::Float16, &v, 1, length)); + } + case VarType::Float32: { + float v = (float) value; + return JitVar::steal(jit_coop_vec_literal(backend, VarType::Float32, &v, 1, length)); + } + case VarType::Float64: { + return JitVar::steal(jit_coop_vec_literal(backend, VarType::Float64, &value, 1, length)); + } + default: + ad_fail("scalar_coop_vec(): unsupported AD scalar type"); + return JitVar(); + } +} + /// Create a scalar Jit variable with the same floating point type and backend /// as an already existing variable with the provided ``index`` DRJIT_INLINE JitVar scalar(Index index, double value) { @@ -147,6 +168,12 @@ DRJIT_INLINE JitVar scalar(Index index, double value) { return scalar(info.backend, info.type, value); } +/// As above, but for cooperative vectors +DRJIT_INLINE JitVar scalar_coop_vec(Index index, double value) { + VarInfo info = jit_set_backend(jit_index(index)); + return scalar_coop_vec(info.backend, info.type, value, jit_coop_vec_length(index)); +} + // ========================================================================== // Central data structures: edges, variables, global state // ========================================================================== @@ -216,7 +243,10 @@ enum VariableFlags : uint8_t { Visited = 1 << 4, /// Is this variable on an iteration boundary of an evaluated loop? - LoopBoundary = 1 << 5 + LoopBoundary = 1 << 5, + + /// Does this variable store a cooperative vector? + CoopVec = 1 << 6 }; /** @@ -311,6 +341,15 @@ struct Variable { void mul_accum(const JitVar &v1, const JitVar &v2, size_t src_size) { JitVar zero = scalar(v1.index(), 0.f), weight; + if (unlikely(flags & CoopVec)) { + // Specialized gradient propagation for cooperative vectors + if (grad.valid()) + grad = JitVar::steal(jit_coop_vec_ternary_op(JitOp::Fma, v1.index(), v2.index(), grad.index())); + else + grad = JitVar::steal(jit_coop_vec_binary_op(JitOp::Mul, v1.index(), v2.index())); + return; + } + // Elide the zero check if ``v2`` is known not to be NaN/infinite if (jit_var_is_finite_literal(v2.index())) weight = v2; @@ -353,6 +392,15 @@ struct Variable { * optimizations. 
*/ void accum(const JitVar& v, size_t src_size) { + if (unlikely(flags & CoopVec)) { + // Specialized gradient propagation for cooperative vectors + if (grad.valid()) + grad = JitVar::steal(jit_coop_vec_binary_op(JitOp::Add, v.index(), grad.index())); + else + grad = v; + return; + } + if (size == 1 && src_size != 1) { /* When this variable is scalar (size == 1) and the source is not (src_size != 1), the gradient must be reduced to a single @@ -857,6 +905,9 @@ static void ad_propagate_size(Variable *v) { } } +/// A tag to signal cooperative weights in the Arg() constructor +struct coop { }; + // This data structure encodes an ordinary dependence on a function argument struct Arg { Arg() = default; @@ -867,6 +918,9 @@ struct Arg { Arg(Index index, double value) : ad_index(::ad_index(index)), weight(scalar(index, value)) { } + Arg(Index index, double value, coop) + : ad_index(::ad_index(index)), weight(scalar_coop_vec(index, value)) { } + Arg(Arg &&a) = default; Arg(const Arg &a) = delete; Arg &operator=(const Arg &a) = delete; @@ -1016,6 +1070,8 @@ DRJIT_NOINLINE Index ad_var_new_impl(const char *label, JitVar &&result, auto [ad_index, var] = ad_var_new(info.backend, info.size, info.type, symbolic, reuse_indices, label); + if (info.is_coop_vec) + var->flags |= VariableFlags::CoopVec; const char *tname = jit_type_name(info.type); if constexpr (N == 0) { @@ -1177,7 +1233,7 @@ void ad_accum_grad(Index index, JitIndex value) { size_t size_in = value_v.size(); if (v->size != size_in && size_in != 1 && size_in != 0 && v->size != 1) - ad_raise("ad_set_grad(): attempted to store a gradient of size " + ad_raise("ad_accum_grad(): attempted to store a gradient of size " "%zu into AD variable a%u, which has size %zu!", size_in, ad_index, v->size); @@ -2288,6 +2344,23 @@ void ad_mark_loop_boundary(Index index) { } } +uint32_t ad_pred(uint32_t ad_index, uint32_t i_) { + if (ad_index == 0) + return 0; + + std::lock_guard guard(state.lock); + const Variable *v = state[ad_index]; + uint32_t edge = v->next_bwd; + + for (uint32_t i = 0; i < i_; ++i) { + if (!edge) + return 0; + edge = state.edges[edge].next_bwd; + } + + return state.edges[edge].source; +} + // ========================================================================== // Implementation of arithmetic operations and transcendental functions @@ -2872,6 +2945,7 @@ Index ad_var_cast(Index i0, VarType vt) { void ad_var_map_put(Index source, Index target) { uint32_t ad_index_source = ad_index(source), ad_index_target = ad_index(target); + ad_log("ad_var_map_put(): a%u -> a%u", ad_index_source, ad_index_target); if (ad_index_target == 0) return; @@ -2918,8 +2992,7 @@ Index ad_var_map_get(Index index) { /// Potentially use ad_var_map_get to rewrite the source or target of a /// gatter/scatter operation static Index ad_var_memop_remap(Index index, bool input) { - uint32_t flags = jit_flags(); - if (flags & (uint32_t) JitFlag::SymbolicScope) { + if (jit_flags() & (uint32_t) JitFlag::SymbolicScope) { index = ad_var_map_get(index); // Add to set of implicit variable dependencies @@ -3047,8 +3120,8 @@ class PacketGather : public dr::detail::CustomOpBase { void ad_var_gather_packet(size_t n, Index source, JitIndex offset, JitIndex mask, uint64_t *out, ReduceMode mode) { - uint32_t *out2 = (uint32_t *) alloca(sizeof(uint32_t) * n); - jit_var_gather_packet(n, jit_index(source), offset, mask, out2); + uint32_t *tmp = (uint32_t *) alloca(sizeof(uint32_t) * n); + jit_var_gather_packet(n, jit_index(source), offset, mask, tmp); ADIndex source_ad = ad_index(source); 
const std::vector &scopes = local_state.scopes; @@ -3064,16 +3137,16 @@ void ad_var_gather_packet(size_t n, Index source, JitIndex offset, op->add_index(backend, source_ad, true); for (size_t i = 0; i < n; ++i) { - out[i] = ad_var_new(out2[i]); - jit_var_dec_ref(out2[i]); + out[i] = ad_var_new(tmp[i]); + jit_var_dec_ref(tmp[i]); op->add_output(ad_index(out[i])); } if (!ad_custom_op(op.get())) - ad_raise("ad_var_gather_packet(): could not create CustomOp"); + ad_raise("ad_var_gather_packet(): could not create CustomOp!"); } else { for (size_t i = 0; i < n; ++i) - out[i] = out2[i]; + out[i] = tmp[i]; } } @@ -3152,11 +3225,11 @@ class PacketScatter : public dr::detail::CustomOpBase { if (op == ReduceOp::Identity && mode != ReduceMode::Permute) { JitMask value(true); - uint32_t *values = (uint32_t *) alloca(sizeof(uint32_t)*n); + uint32_t *tmp = (uint32_t *) alloca(sizeof(uint32_t)*n); for (size_t i = 0; i < n; ++i) - values[i] = value.index(); + tmp[i] = value.index(); m_blend = JitMask::steal(jit_var_scatter_packet( - n, m_blend.index(), values, offset, mask)); + n, m_blend.index(), tmp, offset, mask)); } if (op != ReduceOp::Add && op != ReduceOp::Identity) @@ -3171,7 +3244,7 @@ class PacketScatter : public dr::detail::CustomOpBase { void forward() override { std::lock_guard guard(state.lock); - JitIndex *grad_in = (JitIndex *) alloca(sizeof(JitIndex) * m_n); + JitIndex *grad_in = (JitIndex *) alloca(sizeof(JitIndex) * m_n); size_t n_valid = 0; JitVar zero = scalar(m_backend, m_type, 0.0); @@ -3212,7 +3285,7 @@ class PacketScatter : public dr::detail::CustomOpBase { void backward() override { std::lock_guard guard(state.lock); - JitIndex *out = (JitIndex *) alloca(sizeof(JitIndex) * m_n); + JitIndex *out = (JitIndex *) alloca(sizeof(JitIndex) * m_n); Variable *v = state[m_output_indices[0]]; if (!v->grad.valid()) @@ -3268,23 +3341,20 @@ class PacketScatter : public dr::detail::CustomOpBase { Index ad_var_scatter_packet(size_t n, Index target, const Index *values, JitIndex offset, JitIndex mask, ReduceOp op, ReduceMode mode) { - JitIndex *values2 = (JitIndex *) alloca(sizeof(JitIndex) * n); + JitIndex *tmp = (JitIndex *) alloca(sizeof(JitIndex) * n); bool attached = ad_index(target) != 0; for (size_t i = 0; i < n; ++i) { Index index = values[i]; - values2[i] = jit_index(index); - if (ad_index(index)) - attached = true; + tmp[i] = jit_index(index); + attached |= ad_index(index) != 0; } JitVar result = JitVar::steal(jit_var_scatter_packet( - n, jit_index(target), values2, offset, mask, op, mode)); + n, jit_index(target), tmp, offset, mask, op, mode)); bool perm_scatter = op == ReduceOp::Identity && mode == ReduceMode::Permute; - if (!attached) { - return result.release(); - } else { + if (attached) { // Track implicit dependencies & potentially remap variable IDs target = ad_var_memop_remap(target, false); @@ -3299,11 +3369,13 @@ Index ad_var_scatter_packet(size_t n, Index target, const Index *values, uint64_t ad_result = ad_var_new(result.index()); ps->add_output(ad_index(ad_result)); - if (!ad_custom_op(ps.get())) - ad_raise("ad_var_scatter_packet(): could not create CustomOp"); + if (ad_custom_op(ps.get())) + return ad_result; - return ad_result; + ad_var_dec_ref(ad_result); } + + return result.release(); } void ad_var_scatter_add_kahan(Index *target_1, Index *target_2, Index value, @@ -3563,9 +3635,10 @@ const char *ad_var_graphviz() { if (v->flags & VariableFlags::Symbolic) buffer.put("|{Symbolic}"); - buffer.fmt("|{Type: %s|Size: %zu}|{a%u|Refs: %u}}\"", - type_name_short[v->type], 
v->size, - index, (uint32_t) v->ref_count); + buffer.fmt("|{Type: %s%s|Size: %zu}|{a%u|Refs: %u}}\"", + type_name_short[v->type], + (v->flags & VariableFlags::CoopVec) ? " [coop]" : "", + v->size, index, (uint32_t) v->ref_count); if (color) buffer.fmt(" fillcolor=%s style=filled", color); @@ -3605,7 +3678,7 @@ const char *ad_var_graphviz() { " l4 [style=filled fillcolor=yellowgreen label=\"Gradient present\"];\n" " l3 [style=filled fillcolor=salmon label=\"Input\"];\n" " l2 [style=filled fillcolor=lightblue2 label=\"Output\"];\n" - " l1 [style=filled fillcolor=wheat label=\"Labeled\"];\n" + " l0 [style=filled fillcolor=wheat label=\"Labeled\"];\n" " }\n" "}\n"); @@ -3739,6 +3812,524 @@ void ad_copy_implicit_deps(drjit::vector& result, bool input) { } } +// ========================================================================== +// Cooperative vector API +// ========================================================================== + +class CoopVecPack : public dr::detail::CustomOpBase { +public: + ~CoopVecPack() { + std::lock_guard guard(state.lock); + for (uint32_t index: m_output_indices) + ad_var_dec_ref_int(index, state[index]); + } + + void forward() override { + std::lock_guard guard(state.lock); + uint32_t size = (uint32_t) m_inputs.size(); + JitIndex *tmp = (JitIndex *) alloca(sizeof(JitIndex) * size); + size_t n_valid = 0; + + Variable *target = state[m_output_indices[0]]; + JitVar zero = scalar(m_backend, (VarType) target->type, 0.0); + + for (uint32_t i = 0; i < size; ++i) { + tmp[i] = zero.index(); + + if (m_inputs[i]) { + Variable *v2 = state[m_inputs[i]]; + if (v2->grad.valid()) { + tmp[i] = v2->grad.index(); + n_valid++; + } + } + } + + if (n_valid) { + JitVar packed = JitVar::steal(jit_coop_vec_pack(size, tmp)); + target->accum(packed, target->size); + } + } + + void backward() override { + std::lock_guard guard(state.lock); + uint32_t n = (uint32_t) m_inputs.size(); + + Variable *v = state[m_output_indices[0]]; + if (!v->grad.valid()) + return; + + JitIndex *tmp = (JitIndex *) alloca(sizeof(JitIndex) * n); + jit_coop_vec_unpack(v->grad.index(), n, tmp); + + for (size_t i = 0; i < m_inputs.size(); ++i) { + uint32_t index = m_inputs[i]; + JitVar var = JitVar::steal(tmp[i]); + if (!index) + continue; + Variable *v2 = state[index]; + v2->accum(var, v2->size); + } + } + + void add_input(JitBackend backend, uint32_t index) { + add_index(backend, index, true); + // No need for extra reference counting + m_inputs.push_back(index); + } + + void add_output(JitBackend backend, uint32_t index) { + add_index(backend, index, false); + std::lock_guard guard(state.lock); + ad_var_inc_ref_int(index, state[index]); + } + + const char *name() const override { return "pack"; } + +private: + std::vector m_inputs; +}; + +/// Pack a set of regular Dr.Jit variables to form a cooperative vector +Index ad_coop_vec_pack(uint32_t n, const Index *in) { + JitIndex *tmp = (JitIndex *) alloca(sizeof(JitIndex) * n); + bool attached = false; + + if (n == 0) + return 0; + + for (uint32_t i = 0; i < n; ++i) { + Index index = in[i]; + tmp[i] = jit_index(index); + attached |= ad_index(index) != 0; + } + + JitVar result = JitVar::steal(jit_coop_vec_pack(n, tmp)); + + if (attached) { + VarInfo vi = jit_set_backend(result.index()); + + ref ps = new CoopVecPack(); + for (size_t i = 0; i < n; ++i) + ps->add_input(vi.backend, ad_index(in[i])); + + uint64_t ad_result = ad_var_new(result.index()); + ps->add_output(vi.backend, ad_index(ad_result)); + + if (ad_custom_op(ps.get())) + return ad_result; + + 
ad_var_dec_ref(ad_result); + } + + return result.release(); +} + +class CoopVecUnpack : public dr::detail::CustomOpBase { +public: + ~CoopVecUnpack() { + std::lock_guard guard(state.lock); + for (ADIndex index : m_output_indices) + ad_var_dec_ref_int(index, state[index]); + } + + void forward() override { + std::lock_guard guard(state.lock); + size_t n = m_output_indices.size(); + + const Variable *v = state[m_input_indices[0]]; + if (!v->grad.valid()) + return; + + JitIndex *tmp = (JitIndex *) alloca(sizeof(JitIndex) * n); + jit_coop_vec_unpack(v->grad.index(), n, tmp); + + for (size_t i = 0; i < n; ++i) { + Variable *vo = state[m_output_indices[i]]; + vo->accum(JitVar::steal(tmp[i]), vo->size); + } + } + + void backward() override { + std::lock_guard guard(state.lock); + size_t n = m_output_indices.size(); + + JitIndex *tmp = (JitIndex *) alloca(sizeof(JitIndex) * n); + + for (size_t i = 0; i < n; ++i) { + const Variable *v = state[m_output_indices[i]]; + uint32_t index = v->grad.index(); + + if (index) + jit_var_inc_ref(index); + else + index = scalar(m_backend, (VarType) v->type, 0.0).release(); + + tmp[i] = index; + } + + JitVar packed = JitVar::steal(jit_coop_vec_pack(n, tmp)); + for (size_t i = 0; i < m_output_indices.size(); ++i) + jit_var_dec_ref(tmp[i]); + + Variable *source = state[m_input_indices[0]]; + source->accum(packed, source->size); + } + + void add_output(uint32_t index) { + add_index(m_backend, index, false); + + std::lock_guard guard(state.lock); + ad_var_inc_ref_int(index, state[index]); + } + + const char *name() const override { return "unpack"; } +}; + +/// Unpack a cooperative vector into its components +void ad_coop_vec_unpack(uint64_t index, uint32_t n, uint64_t *out) { + uint32_t *tmp = (uint32_t *) alloca(sizeof(uint32_t) * n); + jit_coop_vec_unpack(index, n, tmp); + + ADIndex ad_index = ::ad_index(index); + const std::vector &scopes = local_state.scopes; + if (!scopes.empty()) + scopes.back().maybe_disable(ad_index); + + if (ad_index) { + ref op = new CoopVecUnpack(); + JitBackend backend = jit_set_backend(jit_index(index)).backend; + op->add_index(backend, ad_index, true); + + for (uint32_t i = 0; i < n; ++i) { + out[i] = ad_var_new(tmp[i]); + jit_var_dec_ref(tmp[i]); + op->add_output(::ad_index(out[i])); + } + + if (!ad_custom_op(op.get())) + ad_raise("ad_coop_vec_unpack(): could not create CustomOp!"); + } else { + for (uint32_t i = 0; i < n; ++i) + out[i] = tmp[i]; + } +} + +/// Perform a unary operation on a cooperative vector +uint64_t ad_coop_vec_unary_op(JitOp op, uint64_t i0) { + JitVar result = JitVar::steal( + jit_coop_vec_unary_op(op, jit_index(i0))); + + if (is_detached(i0)) { + return result.release(); + } else { + switch (op) { + case JitOp::Exp2: { + JitVar scale = scalar_coop_vec(i0, dr::LogTwo); + JitVar w0 = JitVar::steal(jit_coop_vec_binary_op(JitOp::Mul, result.index(), scale.index())); + return ad_var_new("exp2", std::move(result), Arg(i0, std::move(w0))); + } + + case JitOp::Tanh: { + // Mini-max polynomial fit made using Sollya (max. 
relative error = 0.0052) + // Q1 = fpminimax(4*y/((1 + y)^2)-y, [|1, 2, 3, 4, 5|], [|halfprecision...|], [0, 1-1e-20]); + // Q2 = horner(Q1 + y); + // print(Q2); + + JitVar scale = scalar_coop_vec(i0, -2.8853900817779268), // -2/log(2) + c0 = scalar_coop_vec(i0, 3.98046875), + c1 = scalar_coop_vec(i0, -7.4140625), + c2 = scalar_coop_vec(i0, 8.2421875), + c3 = scalar_coop_vec(i0, -5.1640625), + c4 = scalar_coop_vec(i0, 1.35546875); + + JitVar x0 = JitVar::steal(jit_coop_vec_binary_op(JitOp::Mul, (uint32_t) i0, scale.index())), + x1 = JitVar::steal(jit_coop_vec_unary_op(JitOp::Exp2, x0.index())), + y0 = JitVar::steal(jit_coop_vec_ternary_op(JitOp::Fma, x1.index(), c4.index(), c3.index())), + y1 = JitVar::steal(jit_coop_vec_ternary_op(JitOp::Fma, x1.index(), y0.index(), c2.index())), + y2 = JitVar::steal(jit_coop_vec_ternary_op(JitOp::Fma, x1.index(), y1.index(), c1.index())), + y3 = JitVar::steal(jit_coop_vec_ternary_op(JitOp::Fma, x1.index(), y2.index(), c0.index())), + y4 = JitVar::steal(jit_coop_vec_binary_op(JitOp::Mul, x1.index(), y3.index())); + + return ad_var_new("tanh", std::move(result), Arg(i0, std::move(y4))); + } + + default: + ad_raise("ad_coop_vec_unary_op(): differentiable version not implemented."); + } + } +} + +/// Perform a binary operation on a pair of cooperative vectors +uint64_t ad_coop_vec_binary_op(JitOp op, uint64_t i0, uint64_t i1) { + JitVar result = JitVar::steal( + jit_coop_vec_binary_op(op, jit_index(i0), jit_index(i1))); + + if (is_detached(i0, i1)) { + return result.release(); + } else { + switch (op) { + case JitOp::Add: + return ad_var_new("add", std::move(result), + Arg(i0, 1.0, coop{}), + Arg(i1, 1.0, coop{})); + break; + + case JitOp::Sub: + return ad_var_new("sub", std::move(result), + Arg(i0, 1.0, coop{}), + Arg(i1, -1.0, coop{})); + break; + + case JitOp::Mul: + return ad_var_new("mul", std::move(result), + Arg(i0, JitVar::borrow(jit_index(i1))), + Arg(i1, JitVar::borrow(jit_index(i0)))); + break; + + case JitOp::Min: { + JitVar w0 = JitVar::steal(jit_coop_vec_binary_op(JitOp::Step, jit_index(i0), jit_index(i1))), + w1 = JitVar::steal(jit_coop_vec_binary_op(JitOp::Step, jit_index(i1), jit_index(i0))); + return ad_var_new("min", std::move(result), + Arg(i0, std::move(w1)), + Arg(i1, std::move(w0))); + } + break; + + case JitOp::Max: { + JitVar w0 = JitVar::steal(jit_coop_vec_binary_op(JitOp::Step, jit_index(i0), jit_index(i1))), + w1 = JitVar::steal(jit_coop_vec_binary_op(JitOp::Step, jit_index(i1), jit_index(i0))); + return ad_var_new("max", std::move(result), + Arg(i0, std::move(w0)), + Arg(i1, std::move(w1))); + } + break; + + case JitOp::Step: + return result.release(); + + default: + ad_raise("ad_coop_vec_binary_op(): differentiable version not implemented."); + } + } +} + +/// Perform a ternary operation on a triplet of cooperative vectors +uint64_t ad_coop_vec_ternary_op(JitOp op, uint64_t i0, uint64_t i1, uint64_t i2) { + JitVar result = JitVar::steal( + jit_coop_vec_ternary_op(op, jit_index(i0), jit_index(i1), jit_index(i2))); + + if (is_detached(i0, i1, i2)) { + return result.release(); + } else { + switch (op) { + case JitOp::Fma: + return ad_var_new("fma", std::move(result), + Arg(i0, JitVar::borrow(jit_index(i1))), + Arg(i1, JitVar::borrow(jit_index(i0))), + Arg(i2, 1.0, coop{})); + + default: + ad_raise("ad_coop_vec_ternary_op(): differentiable version not implemented."); + } + } +} + + +struct CoopCast : Special { + CoopCast(VarType v1, VarType v2) : v1(v1), v2(v2) { } + + void backward(Variable *source, const Variable *target) override 
{ + source->accum(JitVar::steal(jit_coop_vec_cast(target->grad.index(), v1)), + target->size); + } + + void forward(const Variable *source, Variable *target) override { + target->accum(JitVar::steal(jit_coop_vec_cast(source->grad.index(), v2)), + source->size); + } + + VarType v1, v2; +}; + + +Index ad_coop_vec_cast(Index i0, VarType vt) { + JitVar result = JitVar::steal(jit_coop_vec_cast(jit_index(i0), vt)); + + if (is_detached(i0)) { + return result.release(); + } else { + return ad_var_new("cast", std::move(result), + SpecialArg(i0, new CoopCast(jit_var_type((JitIndex) i0), vt))); + } +} + + +class CoopMatVec : public dr::detail::CustomOpBase { +public: + CoopMatVec(Index A_index, const MatrixDescr *A_descr, Index x_index, + Index b_index, const MatrixDescr *b_descr, int transpose) + : m_A(A_index), m_A_descr(*A_descr), m_x(x_index), m_b(b_index), + m_transpose(transpose) { + if (b_descr) + m_b_descr = *b_descr; + ad_var_inc_ref(m_A); + ad_var_inc_ref(m_b); + ad_var_inc_ref(m_x); + m_out = 0; + } + + ~CoopMatVec() { + ad_var_dec_ref(m_A); + ad_var_dec_ref(m_b); + ad_var_dec_ref(m_x); + ad_var_dec_ref(m_out); + } + + void forward() override { + std::lock_guard guard(state.lock); + + const Variable *A_v = ad_index(m_A) ? state[ad_index(m_A)] : nullptr, + *x_v = ad_index(m_x) ? state[ad_index(m_x)] : nullptr, + *b_v = ad_index(m_b) ? state[ad_index(m_b)] : nullptr; + + Variable *out_v = state[m_output_indices[0]]; + bool has_b_grad = b_v && b_v->grad.valid(); + + if (A_v && A_v->grad.valid()) { + JitVar result = JitVar::steal(jit_coop_vec_matvec( + A_v->grad.index(), &m_A_descr, jit_index(m_x), + has_b_grad ? b_v->grad.index() : 0, + has_b_grad ? &m_b_descr : nullptr, m_transpose)); + out_v->accum(result, out_v->size); + has_b_grad = false; + } + + if (x_v && x_v->grad.valid()) { + JitVar result = JitVar::steal(jit_coop_vec_matvec( + jit_index(m_A), &m_A_descr, x_v->grad.index(), + has_b_grad ? b_v->grad.index() : 0, + has_b_grad ? &m_b_descr : nullptr, m_transpose)); + out_v->accum(result, out_v->size); + has_b_grad = false; + } + + if (has_b_grad) { + JitVar result = JitVar::steal(jit_coop_vec_load( + b_v->grad.index(), m_b_descr.offset, m_b_descr.rows)); + out_v->accum(result, out_v->size); + } + } + + void backward() override { + std::lock_guard guard(state.lock); + Variable *out_v = state[m_output_indices[0]]; + const JitVar &grad = out_v->grad; + + if (!grad.valid()) + return; + + Variable *A_v = ad_index(m_A) ? state[ad_index(m_A)] : nullptr, + *x_v = ad_index(m_x) ? state[ad_index(m_x)] : nullptr, + *b_v = ad_index(m_b) ? state[ad_index(m_b)] : nullptr; + + if (x_v) { + JitVar result = JitVar::steal(jit_coop_vec_matvec( + jit_index(m_A), &m_A_descr, grad.index(), 0, + nullptr, m_transpose == 0 ? 
1 : 0)); + x_v->accum(result, x_v->size); + } + + if (A_v) { + uint32_t vec_a = jit_index(m_x), + vec_b = jit_index(grad.index()); + if (m_transpose) + std::swap(vec_a, vec_b); + + A_v->grad = JitVar::steal(jit_coop_vec_outer_product_accum( + A_v->grad.index(), jit_var_size(jit_index(m_A)), &m_A_descr, + vec_b, vec_a)); + } + + if (b_v) { + b_v->grad = JitVar::steal(jit_coop_vec_accum( + b_v->grad.index(), jit_var_size(jit_index(m_b)), m_b_descr.offset, + grad.index())); + } + } + + void set_output(JitBackend backend, Index index) { + add_index(backend, ad_index(index), false); + m_out = (index >> 32) << 32; + ad_var_inc_ref(m_out); + } + + const char *name() const override { return "matvec"; } + +private: + Index m_A; + MatrixDescr m_A_descr; + Index m_x; + Index m_b; + Index m_out; + MatrixDescr m_b_descr; + int m_transpose; +}; + +uint64_t ad_coop_vec_matvec(uint64_t A_index, const MatrixDescr *A_descr, + uint64_t x_index, uint64_t b_index, + const MatrixDescr *b_descr, int transpose) { + + uint32_t A_index_j = jit_index(A_index), + x_index_j = jit_index(x_index), + b_index_j = jit_index(b_index), + A_index_a = ad_index(A_index), + x_index_a = ad_index(x_index), + b_index_a = ad_index(b_index); + + if (A_index_a || x_index_a || b_index_a) { + const std::vector &scopes = local_state.scopes; + if (!scopes.empty()) { + const Scope &s = scopes.back(); + s.maybe_disable(A_index_a); + s.maybe_disable(x_index_a); + s.maybe_disable(b_index_a); + } + } + + JitVar result = JitVar::steal(jit_coop_vec_matvec( + A_index_j, A_descr, x_index_j, b_index_j, b_descr, transpose)); + + if (!A_index_a && !x_index_a && !b_index_a) { + return result.release(); + } else { + { + std::lock_guard guard(state.lock); + A_index = ad_var_memop_remap(A_index, true); + b_index = ad_var_memop_remap(b_index, true); + A_index_j = jit_index(A_index); + b_index_j = jit_index(b_index); + A_index_a = ad_index(A_index); + b_index_a = ad_index(b_index); + } + + ref op = new CoopMatVec(A_index, A_descr, x_index, + b_index, b_descr, transpose); + JitBackend backend = jit_set_backend(x_index_j).backend; + op->add_index(backend, A_index_a, true); + op->add_index(backend, x_index_a, true); + op->add_index(backend, b_index_a, true); + + uint64_t result_diff = ad_var_new(result.index()); + op->set_output(backend, result_diff); + + if (!ad_custom_op(op.get())) + ad_raise("ad_coop_vec_matvec(): could not create CustomOp!"); + + return result_diff; + } +} + // ========================================================================== // Custom operations // ========================================================================== diff --git a/src/extra/loop.cpp b/src/extra/loop.cpp index ddcb66ab..6b246eb6 100644 --- a/src/extra/loop.cpp +++ b/src/extra/loop.cpp @@ -156,6 +156,7 @@ static size_t ad_loop_evaluated_mask(JitBackend backend, const char *name, JitVar active_it; size_t it = 0; bool grad_suspended = ad_grad_suspended(); + dr::vector copy_bit(indices1.size(), true); while (true) { // Evaluate the loop state @@ -190,14 +191,27 @@ static size_t ad_loop_evaluated_mask(JitBackend backend, const char *name, } for (size_t i = 0; i < indices2.size(); ++i) { - // Kernel caching: Must create an AD copy so that gradient - // computation steps involving this variable (even if unchangecd - // & only used as a read-only dependency) are correctly placed - // within their associated loop iterations. This does not create - // a copy of the underlying JIT variable. - - uint64_t i1 = indices2[i], - i2 = grad_suspended ? 
ad_var_inc_ref(i1) : ad_var_copy(i1); + // Potentially create an AD copy here (i.e., assign a new AD node + // representing a copy of the original loop state). Note that this + // is a symbolic copy in the AD graph that does not consume actual + // device memory. This copy is needed to prevent a degeneracy of + // forward derivative propagation, where a loop variable does not + // change at all, yet all loop iterations depend on this variable in + // a differentiable sense. When the AD traversal reaches this + // variable, this will generate a huge kernel that propagates the + // derivative to every single loop iteration, instead of splitting + // the computation into per-iteration kernels. By creating marked + // (ad_mark_loop-boundary) copies, we can ensure correct sequencing. + // The extra copies are only used within the loop and removed below. + + uint64_t i1 = indices2[i], i2; + + if (!grad_suspended && (i1 >> 32) != 0 && (indices1[i] >> 32) == (i1 >> 32)) { + i2 = ad_var_copy(i1); + } else { + i2 = ad_var_inc_ref(i1); + copy_bit[i] = false; + } ad_var_dec_ref(i1); ad_mark_loop_boundary(i2); @@ -208,7 +222,8 @@ static size_t ad_loop_evaluated_mask(JitBackend backend, const char *name, write_cb(payload, indices2, false); indices1.release(); - indices1.swap(indices2); + indices2.release(); + read_cb(payload, indices1); active_it = JitVar::borrow(cond_cb(payload)); active_it.schedule_(); @@ -216,6 +231,29 @@ static size_t ad_loop_evaluated_mask(JitBackend backend, const char *name, active.schedule_force_(); } + { + bool changed = false; + for (size_t i = 0; i < indices1.size(); ++i) { + if (!copy_bit[i]) + continue; + + // The AD index of this was copied a number of times (see above for + // the rationale). Let's now remove these again. + uint32_t ad_index = (uint32_t) (indices1[i] >> 32); + for (uint32_t j = 0; j < it; ++j) + ad_index = ad_pred(ad_index, 0); + + uint64_t index_new = (((uint64_t) ad_index) << 32) | (uint32_t) indices1[i]; + ad_var_inc_ref(index_new); + ad_var_dec_ref(indices1[i]); + indices1[i] = index_new; + changed = true; + } + + if (changed) + write_cb(payload, indices1, false); + } + return it; } @@ -705,8 +743,7 @@ struct LoopOp : public dr::detail::CustomOpBase { before the loop> state_i = + tracking was enabled before the loop> grad_state_o = @@ -718,26 +755,21 @@ struct LoopOp : public dr::detail::CustomOpBase { dr.disable_grad(state) */ - void bwd_body_simple() { - // Create differentiable loop state variables + void bwd_body_simple() { // Create differentiable loop state variables m_state2.release(); - index32_vector tmp; - size_t offset = m_inputs.size(); for (size_t i = 0; i < m_inputs.size(); ++i) { const Input &in = m_inputs[i]; uint64_t index; - if (in.has_grad_out && in.has_grad_in) { + if (in.has_grad_in) { index = ad_var_new((uint32_t) m_state[i]); - tmp.push_back_borrow((uint32_t) m_state[offset]); + ad_var_map_put(combine(m_input_indices[in.grad_in_index]), index); } else { index = ad_var_inc_ref(m_state[i]); } - m_state2.push_back_steal(index); - if (in.has_grad_out) - offset++; + m_state2.push_back_steal(index); } // Run the loop body @@ -754,16 +786,18 @@ struct LoopOp : public dr::detail::CustomOpBase { m_read_cb(m_payload, m_state2); // AD backward propagation pass - offset = m_inputs.size(); + size_t offset = m_inputs.size(); for (size_t i = 0; i < m_inputs.size(); ++i) { const Input &in = m_inputs[i]; - if (!in.has_grad_out) + if (!in.has_grad_in && !in.has_grad_out) continue; - ad_accum_grad(m_state2[i], (uint32_t) m_state[offset]); - - if 
(!in.has_grad_in) + if (in.has_grad_out && (m_state2[i] >> 32)) { + ad_accum_grad(m_state2[i], (uint32_t) m_state[offset]); ad_enqueue(dr::ADMode::Backward, m_state2[i]); + } else if (in.has_grad_in) { + ad_accum_grad(m_state2[i], (uint32_t) m_state[offset]); + } offset++; } @@ -771,18 +805,23 @@ struct LoopOp : public dr::detail::CustomOpBase { ad_traverse(dr::ADMode::Backward, (uint32_t) dr::ADFlag::ClearNone); // Read the loop output + derivatives copy to loop state vars - m_state.release(); - for (size_t i = 0; i < m_inputs.size(); ++i) - m_state.push_back_borrow((uint32_t) m_state2[i]); + for (size_t i = 0; i < m_inputs.size(); ++i) { + jit_var_inc_ref((uint32_t) m_state2[i]); + ad_var_dec_ref((uint32_t) m_state[i]); + m_state[i] = (uint32_t) m_state2[i]; + } offset = m_inputs.size(); for (size_t i = 0; i < m_inputs.size(); ++i) { const Input &in = m_inputs[i]; - - if (!in.has_grad_out) + if (!in.has_grad_in && !in.has_grad_out) continue; - uint32_t grad = ad_grad(m_state2[i]); - m_state.push_back_steal(grad); + + if (in.has_grad_in) { + ad_var_dec_ref(m_state[offset]); + m_state[offset] = ad_grad(m_state2[i]); + } + offset++; } @@ -796,10 +835,17 @@ struct LoopOp : public dr::detail::CustomOpBase { for (const Input &i : m_inputs) m_state.push_back_borrow(i.index); + uint32_t index = 0; for (const Input &in : m_inputs) { - uint32_t grad; - if (!in.has_grad_out) + index++; + if (!in.has_grad_out && !in.has_grad_in) continue; + uint32_t grad; + + if (in.has_grad_in && in.has_grad_out) + jit_raise("LoopOp::backward_simple(): unsupported " + "configuration. Variable %u (r%u) is marked both as a " + "differentiable output and an input.", index, (uint32_t) in.index); if (in.has_grad_in) { uint64_t zero = 0; @@ -818,29 +864,30 @@ struct LoopOp : public dr::detail::CustomOpBase { [](void *p) { return ((LoopOp *) p)->fwd_cond(); }, [](void *p) { return ((LoopOp *) p)->bwd_body_simple(); }, nullptr, false); - size_t offset = m_inputs.size(); for (const Input &in : m_inputs) { - if (!in.has_grad_out) + if (!in.has_grad_out && !in.has_grad_in) continue; - if (in.has_grad_in) { - ad_accum_grad(combine(m_input_indices[in.grad_in_index]), + if (in.has_grad_out) + ad_accum_grad(combine(m_output_indices[in.grad_out_offset]), (uint32_t) m_state[offset]); - } offset++; } + offset = m_inputs.size(); for (size_t i = 0; i < m_inputs.size(); ++i) { const Input &in = m_inputs[i]; - if (!in.has_grad_out) + if (!in.has_grad_out && !in.has_grad_in) continue; - ad_accum_grad(combine(m_output_indices[in.grad_out_offset]), - (uint32_t) m_state[m_inputs.size() + in.grad_in_offset]); - } + if (in.has_grad_in) + ad_accum_grad(combine(m_input_indices[in.grad_in_index]), + (uint32_t) m_state[offset]); + offset++; + } m_state.release(); } @@ -968,8 +1015,7 @@ bool ad_loop(JitBackend backend, int symbolic, int compress, vt != VarType::Float64) continue; - if (max_iterations == 0 && - (uint32_t) indices_in[i] == (uint32_t) indices_out[i]) { + if ((uint32_t) indices_in[i] == (uint32_t) indices_out[i]) { // Keep unchanged variables out of the AD system if (indices_in[i] != indices_out[i]) { ad_var_inc_ref(indices_in[i]); diff --git a/src/extra/math.cpp b/src/extra/math.cpp index ccad1c88..2168cbd0 100644 --- a/src/extra/math.cpp +++ b/src/extra/math.cpp @@ -93,7 +93,6 @@ DEFINE_MATH_OP(acos) DEFINE_MATH_OP(atan) DEFINE_MATH_OP(sinh) DEFINE_MATH_OP(cosh) -DEFINE_MATH_OP(tanh) DEFINE_MATH_OP(asinh) DEFINE_MATH_OP(acosh) DEFINE_MATH_OP(atanh) @@ -236,3 +235,23 @@ DRJIT_EXTRA_EXPORT uint32_t jit_var_cos(uint32_t i0) { return 0; 
} } + +DRJIT_EXTRA_EXPORT uint32_t jit_var_tanh(uint32_t i0) { + VarInfo info = jit_set_backend(i0); + + switch (info.type) { + case VarType::Float16: + return dr::tanh(Float16::borrow(i0)).release(); + + case VarType::Float32: + if (info.backend == JitBackend::CUDA) + return jit_var_tanh_intrinsic(i0); + return dr::tanh(Float32::borrow(i0)).release(); + + case VarType::Float64: + return dr::tanh(Float64::borrow(i0)).release(); + default: + jit_fail("jit_var_tanh(): invalid operand!"); + return 0; + } +} diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 9886323d..678ac1fb 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -21,7 +21,7 @@ configure_file( ) set(PY_FILES - config.py __init__.py ast.py detail.py interop.py dda.py opt.py + config.py __init__.py ast.py detail.py interop.py dda.py opt.py nn.py _sh_eval.py _reduce.py scalar/__init__.py llvm/__init__.py llvm/ad.py cuda/__init__.py cuda/ad.py) @@ -83,6 +83,7 @@ nanobind_add_module( tracker.h tracker.cpp local.h local.cpp resample.h resample.cpp + coop_vec.h coop_vec.cpp # Backends scalar.h scalar.cpp @@ -211,6 +212,13 @@ if (NOT (DRJIT_SANITIZE_ASAN OR DRJIT_SANITIZE_UBSAN)) ${STUB_ARGS} ) + nanobind_add_stub( + drjit-stub-nn + MODULE drjit.nn + OUTPUT ${DRJIT_PYTHON_DST_DIR}/nn.pyi + ${STUB_ARGS} + ) + nanobind_add_stub( drjit-stub-scalar MODULE drjit.scalar diff --git a/src/python/apply.cpp b/src/python/apply.cpp index c611e099..fbfc0c47 100644 --- a/src/python/apply.cpp +++ b/src/python/apply.cpp @@ -446,7 +446,8 @@ NB_NOINLINE PyObject *apply_tensor(ArrayOp op, Slot slot, expanded_shapes_alloc[index] = vector(ndim, 1); vector& expanded_shape = expanded_shapes_alloc[index]; size_t offset = ndim - src_ndim; - memcpy(&expanded_shape[offset], shape->data(), sizeof(size_t) * src_ndim); + if (src_ndim) + memcpy(&expanded_shape[offset], shape->data(), sizeof(size_t) * src_ndim); return (const vector*)&expanded_shape; }; @@ -635,20 +636,18 @@ void traverse(const char *op, TraverseCallback &tc, nb::handle h) { } else if (tp.is(&PyDict_Type)) { for (nb::handle h2 : nb::borrow(h).values()) traverse(op, tc, h2); - } else { - if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { - for (auto [k, v] : ds) - traverse(op, tc, nb::getattr(h, k)); - } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { - for (nb::handle field : df) { - nb::object k = field.attr(DR_STR(name)); - traverse(op, tc, nb::getattr(h, k)); - } - } else if (nb::object cb = get_traverse_cb_ro(tp); cb.is_valid()) { - cb(h, nb::cpp_function([&](uint64_t index) { tc(index); })); - } else { - tc.traverse_unknown(h); + } else if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { + for (auto [k, v] : ds) + traverse(op, tc, nb::getattr(h, k)); + } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { + for (nb::handle field : df) { + nb::object k = field.attr(DR_STR(name)); + traverse(op, tc, nb::getattr(h, k)); } + } else if (nb::object cb = get_traverse_cb_ro(tp); cb.is_valid()) { + cb(h, nb::cpp_function([&](uint64_t index) { tc(index); })); + } else { + tc.traverse_unknown(h); } } catch (nb::python_error &e) { nb::raise_from(e, PyExc_RuntimeError, @@ -889,25 +888,23 @@ nb::object transform(const char *op, TransformCallback &tc, nb::handle h) { for (auto [k, v] : nb::borrow(h)) tmp[k] = transform(op, tc, v); result = std::move(tmp); - } else { - if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { - nb::object tmp = tp(); - for (auto [k, v] : ds) - nb::setattr(tmp, k, transform(op, tc, 
nb::getattr(h, k))); - result = std::move(tmp); - } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { - nb::object tmp = nb::dict(); - for (nb::handle field : df) { - nb::object k = field.attr(DR_STR(name)); - tmp[k] = transform(op, tc, nb::getattr(h, k)); - } - result = tp(**tmp); - } else if (nb::object cb = get_traverse_cb_rw(tp); cb.is_valid()) { - cb(h, nb::cpp_function([&](uint64_t index) { return tc(index); })); - result = nb::borrow(h); - } else if (!result.is_valid()) { - result = tc.transform_unknown(h); + } else if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { + nb::object tmp = tp(); + for (auto [k, v] : ds) + nb::setattr(tmp, k, transform(op, tc, nb::getattr(h, k))); + result = std::move(tmp); + } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { + nb::object tmp = nb::dict(); + for (nb::handle field : df) { + nb::object k = field.attr(DR_STR(name)); + tmp[k] = transform(op, tc, nb::getattr(h, k)); } + result = tp(**tmp); + } else if (nb::object cb = get_traverse_cb_rw(tp); cb.is_valid()) { + cb(h, nb::cpp_function([&](uint64_t index) { return tc(index); })); + result = nb::borrow(h); + } else if (!result.is_valid()) { + result = tc.transform_unknown(h); } return result; } catch (nb::python_error &e) { diff --git a/src/python/autodiff.cpp b/src/python/autodiff.cpp index 46f27fd0..e1f654ba 100644 --- a/src/python/autodiff.cpp +++ b/src/python/autodiff.cpp @@ -12,6 +12,7 @@ #include #include #include "autodiff.h" +#include "coop_vec.h" #include "apply.h" #include "meta.h" #include "init.h" @@ -43,6 +44,24 @@ static void set_grad_enabled(nb::handle h, bool enable_) { } } } + + void traverse_unknown(nb::handle h) override { + if (CoopVec *v = nullptr; nb::try_cast(h, v, false), v != nullptr) { + uint64_t index = v->m_index; + bool grad_enabled = ((uint32_t) index) != index; + if (enable != grad_enabled) { + if (enable) { + nb::raise( + "to create a differentiable cooperative vector, " + "construct it from grad-enabled components."); + } else { + jit_var_inc_ref((uint32_t) index); + ad_var_dec_ref(index); + v->m_index = (uint32_t) index; + } + } + } + } }; SetGradEnabled sge(enable_); @@ -88,6 +107,11 @@ bool grad_enabled(nb::handle h) { if (s.is_diff && is_float(s)) result |= ad_grad_enabled(s.index(inst_ptr(h))) != 0; } + + void traverse_unknown(nb::handle h) override { + if (CoopVec *v = nullptr; nb::try_cast(h, v, false), v != nullptr) + result |= ad_grad_enabled(v->m_index); + } }; GradEnabled ge; @@ -139,6 +163,15 @@ static nb::object detach(nb::handle h, bool preserve_type_ = true) { nb::inst_copy(h2, h1); } } + + nb::object transform_unknown(nb::handle h) const override { + if (CoopVec *v = nullptr; nb::try_cast(h, v, false), v != nullptr) { + uint32_t index = (uint32_t) v->m_index; + jit_var_inc_ref(index); + return nb::cast(CoopVec(index, v->m_size, v->m_type)); + } + return nb::borrow(h); + } }; if ((is_drjit_array(h) && !supp(h.type()).is_diff)) diff --git a/src/python/base.cpp b/src/python/base.cpp index 1cc02a86..f5b5761d 100644 --- a/src/python/base.cpp +++ b/src/python/base.cpp @@ -1261,7 +1261,7 @@ void export_base(nb::module_ &m) { m.def("power", [](Py_ssize_t arg0, Py_ssize_t arg1) { return std::pow(arg0, arg1); }, - doc_pow); + doc_power); m.def("power", [](double arg0, double arg1) { return std::pow(arg0, arg1); }); diff --git a/src/python/coop_vec.cpp b/src/python/coop_vec.cpp new file mode 100644 index 00000000..4f3a8d0c --- /dev/null +++ b/src/python/coop_vec.cpp @@ -0,0 +1,740 @@ +/* + src/coop_vec.cpp -- Python 
bindings for Cooperative CoopVecs + + Copyright (c) 2025 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a BSD-style + license that can be found in the LICENSE file. +*/ + +#include "common.h" +#include "base.h" +#include "init.h" +#include "meta.h" +#include "apply.h" +#include "coop_vec.h" +#include +#include "nanobind/nanobind.h" +#include "nanobind/nb_defs.h" +#include +#include +#include +#include + + +/// Cooperative vector constructor +CoopVec::CoopVec(nb::handle arg) { + construct(arg); +} + +void CoopVec::construct(nb::handle arg) { + nb::handle single_arg = nb::none(); + if (nb::len(arg) == 1) + single_arg = arg[0]; + + if (CoopVec *v = nullptr; nb::try_cast(single_arg, v, false), v != nullptr) { + m_index = ad_var_inc_ref(v->m_index); + m_size = v->m_size; + m_type = v->m_type; + return; + } + + nb::handle arg_tp = single_arg.type(); + if (is_drjit_type(arg_tp)) { + const ArraySupplement &s = supp(arg_tp); + if (s.is_tensor) { + const dr::vector &shape = s.tensor_shape(inst_ptr(single_arg)); + if (shape.size() <= 2) { + construct(nb::list(single_arg)); + return; + } + } + } + + /// Flatten a PyTree into a set of 1D arrays used to construct a cooperative vector + struct Flatten: TraverseCallback { + std::vector result; + + void operator()(nb::handle h) { + if ((JitBackend) supp(h.type()).backend != JitBackend::None) + result.push_back(nb::borrow(h)); + } + + void traverse_unknown(nb::handle h) { + if (PyIter_Check(h.ptr())) + traverse("drjit.nn.CoopVec", *this, nb::list(h)); + else if (PyLong_CheckExact(h.ptr()) || PyFloat_CheckExact(h.ptr())) + result.push_back(nb::borrow(h)); + else + nb::raise("encountered an unknown type \"%s\"", nb::inst_name(h).c_str()); + } + }; + + Flatten cb; + traverse("drjit.nn.CoopVec", cb, arg); + + uint32_t size = (uint32_t) cb.result.size(); + + if (cb.result.empty()) + nb::raise("drjit.nn.CoopVec(): cannot be empty!"); + + // Identify type + for (uint32_t i = 0; i < size; ++i) { + nb::handle tp = cb.result[i].type(); + if (is_drjit_type(tp)) { + m_type = tp; + break; + } + } + + // Check that this type makes sense + if (!m_type.is_valid()) + nb::raise_type_error( + "drjit.nn.CoopVec(): at least one Jit-compiled 1D array is required as input " + "(e.g., of type 'drjit.cuda.Float16')!"); + + const ArraySupplement &s = supp(m_type); + if (s.ndim != 1 || (JitBackend) s.backend == JitBackend::None) + nb::raise_type_error( + "drjit.nn.CoopVec(): expected Jit-compiled 1D arrays as input " + "(e.g., of type 'drjit.cuda.Float16')!"); + + // Check/cast the other arguments + uint64_t *tmp = (uint64_t *) alloca(sizeof(uint64_t) * size); + for (uint32_t i = 0; i < size; ++i) { + nb::object value = cb.result[i]; + try { + if (!value.type().is(m_type)) { + value = m_type(value); + cb.result[i] = value; + } + tmp[i] = s.index(inst_ptr(value)); + } catch (...) 
{ + nb::raise_type_error( + "drjit.nn.CoopVec.__init__(): encountered an incompatible " + "argument of type \"%s\" (expected \"%s\")!", + nb::inst_name(value).c_str(), + nb::type_name(m_type).c_str()); + } + } + + m_index = ad_coop_vec_pack(size, tmp); + m_size = size; +} + +/// Unpack a cooperative vector into a Python list +nb::list CoopVec::expand_to_list() const { + if (m_size == 0) + return nb::list(); + + uint64_t *tmp = (uint64_t *) alloca(m_size * sizeof(uint64_t)); + ad_coop_vec_unpack(m_index, m_size, tmp); + + nb::list result; + const ArraySupplement &s = supp(m_type); + for (uint32_t i = 0; i < m_size; ++i) { + nb::object o = nb::inst_alloc(m_type); + s.init_index(tmp[i], inst_ptr(o)); + ad_var_dec_ref(tmp[i]); + nb::inst_mark_ready(o); + result.append(std::move(o)); + } + return result; +} + +/// Unpack a cooperative vecotr into a Dr.Jit array type like CoopVecXf +nb::object CoopVec::expand_to_vector() const { + ArrayMeta m = supp(m_type); + m.ndim = 2; + m.shape[0] = DRJIT_DYNAMIC; + m.shape[1] = DRJIT_DYNAMIC; + return meta_get_type(m)(expand_to_list()); +} + +/// Perform one of several supported unary operations +template static CoopVec coop_vec_unary_op(const CoopVec &arg) { + if ((JitBackend) supp(arg.m_type).backend == JitBackend::LLVM) { + nb::object unpacked = arg.expand_to_vector(), func; + + switch (Op) { + case JitOp::Exp2: func = array_module.attr("exp2"); break; + case JitOp::Tanh: func = array_module.attr("tanh"); break; + case JitOp::Log2: func = array_module.attr("log2"); break; + default: + nb::raise("Unsupported operation!"); + } + + return CoopVec(func(unpacked)); + } + + return CoopVec( + ad_coop_vec_unary_op(Op, arg.m_index), + arg.m_size, + arg.m_type + ); +} + +/// Perform one of several supported binary operations +template +static nb::object coop_vec_binary_op(nb::handle h0, nb::handle h1) { + nb::object o[2] { nb::borrow(h0), nb::borrow(h1) }; + CoopVec *ptr[2] { }; + CoopVec *c = nullptr; + + for (uint32_t i = 0; i < 2; ++i) { + if (nb::try_cast(o[i], ptr[i], false)) + c = ptr[i]; + } + if (!c) + return nb::steal(NB_NEXT_OVERLOAD); + + for (uint32_t i = 0; i < 2; ++i) { + if (ptr[i]) + continue; + + nb::list args; + nb::object oi = c->m_type(o[i]); + for (uint32_t j = 0; j < c->m_size; ++j) + args.append(oi); + + o[i] = nb::cast(CoopVec(nb::borrow(nb::tuple(args)))); + if (!nb::try_cast(o[i], ptr[i], false)) + nb::raise("CoopVec::binary_op(): internal error"); + } + + return nb::cast(CoopVec( + ad_coop_vec_binary_op( + Op, + ptr[0]->m_index, + ptr[1]->m_index + ), + c->m_size, + c->m_type + )); +} + +/// Perform a ternary operation (currently only FMA) +template +static nb::object coop_vec_ternary_op(nb::handle h0, nb::handle h1, + nb::handle h2) { + nb::object o[3] { nb::borrow(h0), nb::borrow(h1), nb::borrow(h2) }; + CoopVec *ptr[3] { }; + CoopVec *c = nullptr; + + for (uint32_t i = 0; i < 3; ++i) { + if (nb::try_cast(o[i], ptr[i], false)) + c = ptr[i]; + } + if (!c) + return nb::steal(NB_NEXT_OVERLOAD); + + for (uint32_t i = 0; i < 3; ++i) { + if (ptr[i]) + continue; + + nb::list args; + for (uint32_t j = 0; j < c->m_size; ++j) + args.append(c->m_type(o[i])); + + o[i] = nb::cast(CoopVec(nb::borrow(nb::tuple(args)))); + if (!nb::try_cast(o[i], ptr[i], false)) + nb::raise("CoopVec::ternary_op(): internal error"); + } + + return nb::cast(CoopVec( + ad_coop_vec_ternary_op( + Op, + ptr[0]->m_index, + ptr[1]->m_index, + ptr[2]->m_index + ), + c->m_size, + c->m_type + )); +} + +/// Matrix-vector product +static CoopVec matvec(const MatrixView &A, + const 
CoopVec &x, + std::optional b, + bool transpose) { + + return { + ad_coop_vec_matvec( + A.index(), + &A.descr, + x.m_index, + b.has_value() ? b.value()->index() : 0, + b.has_value() ? &b.value()->descr : nullptr, + ((int) transpose) ^ ((int) A.transpose) + ), + transpose ? A.descr.cols : A.descr.rows, + x.m_type + }; +} + +nb::str MatrixView::repr() const { + const char *layout; + switch (descr.layout) { + case MatrixLayout::InferencingOptimal: layout = "inference"; break; + case MatrixLayout::TrainingOptimal: layout = "training"; break; + case MatrixLayout::RowMajor: layout = "row_major"; break; + default: layout = "unknown"; break; + } + return nb::str( + "drjit.nn.MatrixView[\n" + " dtype={},\n" + " layout=\"{}\",\n" + " shape=({}, {}),\n" + " stride={},\n" + " offset={}\n" + " size={}\n" + " buffer=<{} instance>\n" + "]" + ).format( + descr.dtype, + layout, + descr.rows, + descr.cols, + descr.stride, + descr.offset, + descr.size, + inst_name(buffer) + ); +} + +uint64_t MatrixView::index() const { + return supp(buffer.type()).index(inst_ptr(buffer)); +} + +MatrixView MatrixView::getitem(nb::object arg) const { + nb::object s[2]; + + if (descr.layout == MatrixLayout::InferencingOptimal || + descr.layout == MatrixLayout::TrainingOptimal) + nb::raise("drjit.MatrixView.__getitem__(): slicing is not permitted for " + "training/inferencing-optimal layouts!"); + + if (nb::isinstance(arg)) { + size_t l = nb::len(arg); + if (l == 0 || l > 2) + nb::raise("drjit.MatrixView.__getitem__(): expected 1 or 2 terms in " + "slice expression (got %zu)!", l); + s[0] = arg[0]; + if (l == 2) + s[1] = arg[1]; + } else { + s[0] = arg; + } + + if (!s[1].is_valid()) + s[1] = nb::slice(nb::none(), nb::none(), nb::none()); + + Py_ssize_t start[2], step[2]; + size_t len[2]; + + for (uint32_t i = 0; i < 2; ++i) { + uint32_t value; + if (nb::try_cast(s[i], value, false)) + s[i] = nb::slice(nb::int_(value), nb::int_(value + 1), nb::int_(1)); + nb::slice sl; + if (!nb::try_cast(s[i], sl, false)) + nb::raise("drjit.MatrixView.__getitem__(): expected 'int' or 'slice' " + "in slice expression, got '%s'!", + nb::inst_name(s[i]).c_str()); + size_t limit = i == 0 ? descr.rows : descr.cols; + auto [start_i, stop_i, step_i, len_i] = + sl.compute(limit); + start[i] = start_i; step[i] = step_i; len[i] = len_i; + } + + if (step[1] != 1) + nb::raise("drjit.MatrixView.__getitem__(): rows elements must be contiguous!"); + + if (len[0] == 0 || len[1] == 0) + nb::raise("drjit.MatrixView.__getitem__(): input array may not be empty!"); + + MatrixView result; + result.descr.rows = len[0]; + result.descr.cols = len[1]; + result.descr.offset = descr.offset + start[0] * descr.stride + start[1]; + result.descr.dtype = descr.dtype; + result.descr.layout = descr.layout; + result.descr.stride = descr.stride * step[0]; + result.descr.size = (len[0] - 1) * result.descr.stride + len[1]; + result.buffer = buffer; + return result; +} + +static MatrixView view(nb::handle_t arg) { + MatrixView result { }; + MatrixDescr &d = result.descr; + + const ArraySupplement &s = supp(arg.type()); + + d.dtype = (VarType) s.type; + d.layout = MatrixLayout::RowMajor; + + if (s.is_tensor) { + const dr::vector &shape = s.tensor_shape(inst_ptr(arg)); + if (shape.size() != 1 && shape.size() != 2) + nb::raise("drjit.view(): tensor must have 1 or 2 dimensions!"); + d.rows = shape[0]; + d.cols = shape.size() > 1 ? 
shape[1] : 1; + result.buffer = nb::steal(s.tensor_array(arg.ptr())); + } else if (s.ndim == 1 && s.shape[0] == DRJIT_DYNAMIC) { + d.rows = nb::len(arg); + d.cols = 1; + result.buffer = nb::borrow(arg); + } else { + nb::raise("Unsupported input type!"); + } + + d.stride = d.cols; + d.size = d.rows * d.cols; + d.offset = 0; + + if (d.rows == 0 || d.cols == 0) + nb::raise("drjit.view(): input array/tensor may not be empty!"); + + return result; +} + +struct RepackItem { + nb::object in_o; + nb::object out_o; + MatrixView *in; + MatrixView *out; + + RepackItem(nb::handle in_o, nb::handle out_o, MatrixView *in, MatrixView *out) + : in_o(nb::borrow(in_o)), out_o(nb::borrow(out_o)), in(in), out(out) { } + RepackItem(RepackItem&&) = default; + RepackItem(const RepackItem&) = default; +}; + +nb::handle view_type; +nb::handle coop_vector_type; + +static nb::object repack_impl(const char *name, MatrixLayout layout, + nb::handle arg_, uint32_t &offset, + std::vector &items) { + nb::handle arg_tp = arg_.type(); + nb::object arg = nb::borrow(arg_); + + if (is_drjit_type(arg_tp) && layout != MatrixLayout::RowMajor) { + arg = nb::cast(view(nb::handle_t(arg))); + arg_tp = view_type; + } + + if (arg_tp.is(view_type)) { + MatrixView *in_view = nb::cast(arg, false); + uint64_t in_index = supp(in_view->buffer.type()).index(inst_ptr(in_view->buffer)); + MatrixDescr out_descr = + jit_coop_vec_compute_layout(in_index, &in_view->descr, layout, offset); + MatrixView *out_view = new MatrixView{out_descr, nb::none()}; + nb::object result = nb::cast(out_view, nb::rv_policy::take_ownership); + items.emplace_back(arg, result, in_view, out_view); + offset = out_descr.offset + out_descr.size; + return result; + } else if (arg_tp.is(&PyTuple_Type)) { + nb::tuple t = nb::borrow(arg); + nb::list result; + for (nb::handle h : t) + result.append(repack_impl(name, layout, h, offset, items)); + return nb::tuple(result); + } else if (arg_tp.is(&PyList_Type)) { + nb::list l = nb::borrow(arg); + nb::list result; + for (nb::handle h : l) + result.append(repack_impl(name, layout, h, offset, items)); + return std::move(result); + } else if (arg_tp.is(&PyDict_Type)) { + nb::dict d = nb::borrow(arg); + nb::dict result; + for (auto [k, v] : d) + result[k] = repack_impl(name, layout, v, offset, items); + return std::move(result); + } else if (nb::dict ds = get_drjit_struct(arg_tp); ds.is_valid()) { + nb::object tmp = arg_tp(); + for (auto [k, v] : ds) + nb::setattr(tmp, k, repack_impl(name, layout, nb::getattr(arg, k), offset, items)); + return tmp; + } else if (nb::object df = get_dataclass_fields(arg_tp); df.is_valid()) { + nb::object tmp = nb::dict(); + for (nb::handle field : df) { + nb::object k = field.attr(DR_STR(name)); + tmp[k] = repack_impl(name, layout, nb::getattr(arg, k), offset, items); + } + return arg_tp(**tmp); + } else { + return nb::borrow(arg); + } +} + +static std::pair repack(const char *name, const char *layout_str, nb::handle arg) { + uint32_t offset = 0; + std::vector items; + MatrixLayout layout; + + if (layout_str) { + if (strcmp(layout_str, "inference") == 0) + layout = MatrixLayout::InferencingOptimal; + else if (strcmp(layout_str, "training") == 0) + layout = MatrixLayout::TrainingOptimal; + else + nb::raise("drjit.%s(): 'mode' must equal \"inference\" or \"training\"!", name); + } else { + layout = MatrixLayout::RowMajor; + } + + nb::object result = repack_impl(name, layout, arg, offset, items); + nb::object buffer = nb::none(); + + if (items.size() > 0) { + nb::handle buf_cur = items[0].in->buffer, + 
buf_tp = buf_cur.type(); + + buffer = full("zeros", buf_tp, nb::int_(0), offset, true); + const ArraySupplement &s = supp(buf_tp); + + std::vector in, out; + in.reserve(items.size()); + out.reserve(items.size()); + + auto submit = [&] { + jit_coop_vec_pack_matrices( + (uint32_t) in.size(), + s.index(inst_ptr(buf_cur)), + in.data(), + s.index(inst_ptr(buffer)), + out.data() + ); + }; + + for (size_t i = 0; i < items.size(); ++i) { + nb::handle buf_i = items[i].in->buffer, + buf_i_tp = buf_i.type(); + + if (!buf_i_tp.is(buf_tp)) { + nb::raise_type_error( + "drjit.%s(): encountered different input formats (%s vs %s)", name, + nb::type_name(buf_tp).c_str(), + nb::type_name(buf_i_tp).c_str()); + } + + if (!buf_cur.is(buf_i)) { + submit(); + in.clear(); + out.clear(); + buf_cur = buf_i; + } + + items[i].out->buffer = buffer; + + in.push_back(items[i].in->descr); + out.push_back(items[i].out->descr); + } + + if (!in.empty()) + submit(); + } + + return { buffer, result }; +} + +static CoopVec coopvec_abs_workaround(nb::handle_t &v) { + nb::list result; + for (nb::handle h: v) + result.append(nb::steal(PyNumber_Absolute(h.ptr()))); + return CoopVec(result); +} + +void export_coop_vec(nb::module_ &m) { + nb::module_ nn = m.def_submodule("detail").def_submodule("nn"); + nn.attr("__name__") = "drjit.nn"; + + nn.attr("ArrayT") = nb::type_var("ArrayT", "bound"_a = "drjit.ArrayBase"); + for (const char *name : + { "T", "SelfT", "SelfCpT", "ValT", "ValCpT", "RedT", "PlainT", "MaskT" }) + nn.attr(name) = nb::type_var(name); + + coop_vector_type = nb::class_(nn, "CoopVec", nb::is_generic(), nb::sig("class CoopVec(typing.Generic[T])")) + .def(nb::init(), + nb::sig("def __init__(self, *args: typing.Unpack[typing.Tuple[typing.Union[drjit.ArrayBase[SelfT, SelfCpT, ValT, ValCpT, T, PlainT, MaskT], float, int], ...]]) -> None"), + doc_nn_CoopVec_init) + .def("__iter__", [](const CoopVec &v) { return iter(v.expand_to_list()); }, + nb::sig("def __iter__(self, /) -> typing.Iterator[T]")) + .def("__add__", &coop_vec_binary_op, + nb::sig("def __add__(self, arg: CoopVec[T] | T | float | int, /) -> CoopVec[T]")) + .def("__radd__", &coop_vec_binary_op, + nb::sig("def __radd__(self, arg: CoopVec[T] | T | float | int, /) -> CoopVec[T]")) + .def("__sub__", &coop_vec_binary_op, + nb::sig("def __sub__(self, arg: CoopVec[T] | T | float | int, /) -> CoopVec[T]")) + .def("__rsub__", &coop_vec_binary_op, + nb::sig("def __rsub__(self, arg: CoopVec[T] | T | float | int, /) -> CoopVec[T]")) + .def("__mul__", &coop_vec_binary_op, + nb::sig("def __mul__(self, arg: CoopVec[T] | T | float | int, /) -> CoopVec[T]")) + .def("__rmul__", &coop_vec_binary_op, + nb::sig("def __rmul__(self, arg: CoopVec[T] | T | float | int, /) -> CoopVec[T]")) + .def_prop_ro("index", [](const CoopVec &v) { return v.m_index; }) + .def_prop_ro("type", [](const CoopVec &v) { return v.m_type; }) + .def("__len__", [](const CoopVec &v) { return v.m_size; }) + .def("__abs__", &coopvec_abs_workaround) + .def("__repr__", + [](const CoopVec &v) { + return nb::str("drjit.nn.CoopVec[{}, shape=({}, {})]") + .format(nb::type_name(v.m_type), v.m_size, + jit_var_size(v.m_index)); + }); + + view_type = nb::class_(nn, "MatrixView", doc_nn_MatrixView) + .def(nb::init<>()) + .def("__repr__", &MatrixView::repr) + .def("__getitem__", &MatrixView::getitem, + nb::sig("def __getitem__(self, arg: typing.Union[int, slice, typing.Tuple[typing.Union[int, slice], typing.Union[int, slice]]]) -> MatrixView")) + .def_prop_rw("dtype", + [](MatrixView &v) { return v.descr.dtype; }, + 
[](MatrixView &v, VarType v2) { v.descr.dtype = v2; }) + .def_prop_rw("offset", + [](MatrixView &v) { return v.descr.offset; }, + [](MatrixView &v, uint32_t v2) { v.descr.offset = v2; }) + .def_prop_rw("stride", + [](MatrixView &v) { return v.descr.stride; }, + [](MatrixView &v, uint32_t v2) { v.descr.stride = v2; }) + .def_prop_rw("size", + [](MatrixView &v) { return v.descr.size; }, + [](MatrixView &v, uint32_t v2) { v.descr.size = v2; }) + .def_prop_rw("layout", + [](MatrixView &v) { + switch (v.descr.layout) { + case MatrixLayout::InferencingOptimal: return "inference"; + case MatrixLayout::TrainingOptimal: return "training"; + case MatrixLayout::RowMajor: return "row_major"; + default: return "unknown"; + } + }, + [](MatrixView &v, const char *s) { + if (strcmp(s, "inference") == 0) + v.descr.layout = MatrixLayout::InferencingOptimal; + else if (strcmp(s, "training") == 0) + v.descr.layout = MatrixLayout::TrainingOptimal; + else if (strcmp(s, "row_major") == 0) + v.descr.layout = MatrixLayout::RowMajor; + else + nb::raise("Unknown layout!"); + }, + nb::for_getter(nb::sig("def layout(self) -> typing.Literal['inference', 'training', 'row_major']")), + nb::for_setter(nb::sig("def layout(self, value: typing.Literal['inference', 'training', 'row_major']) -> None"))) + .def_prop_rw("transpose", + [](MatrixView &v) { return v.transpose; }, + [](MatrixView &v, bool v2) { v.transpose = v2; }) + .def_prop_rw("shape", + [](MatrixView &v) { + return std::make_pair(v.descr.rows, v.descr.cols); + }, + [](MatrixView &v, std::pair v2) { + v.descr.rows = v2.first; + v.descr.cols = v2.second; + }) + .def("__matmul__", [](const MatrixView &self, const CoopVec &x) { return matvec(self, x, {}, false); }, + nb::sig("def __matmul__(self, arg: CoopVec[T], /) -> CoopVec[T]")) + .def_rw("buffer", &MatrixView::buffer) + .def_prop_ro("T", + [](MatrixView &v) { + MatrixView r; + r.descr = v.descr; + r.buffer = v.buffer; + r.transpose = !v.transpose; + return r; + }) + .def_prop_ro("grad", + [](MatrixView &v) { + MatrixView r; + r.descr = v.descr; + r.buffer = v.buffer.attr("grad"); + r.transpose = v.transpose; + return r; + }); + + + nb::dict drjit_struct; + drjit_struct["layout"] = nb::handle(&PyUnicode_Type); + drjit_struct["buffer"] = nb::none(); + drjit_struct["dtype"] = nb::type(); + drjit_struct["shape"] = nb::handle(&PyTuple_Type); + drjit_struct["offset"] = nb::handle(&PyLong_Type); + drjit_struct["size"] = nb::handle(&PyLong_Type); + drjit_struct["stride"] = nb::handle(&PyLong_Type); + drjit_struct["transpose"] = nb::handle(&PyBool_Type); + view_type.attr("DRJIT_STRUCT") = drjit_struct; + + nn.def("view", &view, + doc_nn_view); + + nn.def("pack", [](nb::handle arg, const char *layout) { return repack("pack", layout, arg); }, + nb::arg(), "layout"_a = "inference", + nb::sig("def pack(arg: MatrixView | drjit.AnyArray, *, layout: typing.Literal['inference', 'training'] = 'inference') -> typing.Tuple[drjit.ArrayBase, MatrixView]"), + doc_nn_pack); + + nn.def("pack", + [](nb::args args, const char *layout) { + auto temp = repack("pack", layout, args); + nb::list l; + l.append(temp.first); + l.extend(temp.second); + return nb::tuple(l); + }, + "args"_a, "layout"_a = "inference", + nb::sig("def pack(*args: PyTree, layout: typing.Literal['inference', " + "'training'] = 'inference') -> typing.Tuple[drjit.ArrayBase, " + "typing.Unpack[typing.Tuple[PyTree, ...]]]")); + + nn.def("unpack", [](nb::handle arg) { + return repack("unpack", nullptr, arg); }, + nb::sig("def unpack(arg: MatrixView | drjit.AnyArray, /) -> 
typing.Tuple[drjit.ArrayBase, MatrixView]"), + doc_nn_unpack); + + nn.def("unpack", + [](nb::args args) { + auto temp = repack("unpack", nullptr, args); + nb::list l; + l.append(temp.first); + l.extend(temp.second); + return nb::tuple(l); + }, + "args"_a, + nb::sig("def unpack(*args: PyTree) -> typing.Tuple[drjit.ArrayBase, " + "typing.Unpack[typing.Tuple[PyTree, ...]]]")); + + nn.def("matvec", &matvec, "A"_a.noconvert(), "x"_a.noconvert(), + "b"_a.noconvert() = nb::none(), "transpose"_a = false, + nb::sig("def matvec(A: MatrixView, x: drjit.nn.CoopVec[T], b: typing.Optional[MatrixView] = " + "None, /, transpose: bool = False) -> drjit.nn.CoopVec[T]"), + doc_nn_matvec); + + nn.def("cast", + [](CoopVec vec, nb::type_object_t tp) { + const ArraySupplement &s = supp(tp); + ArrayMeta m = supp(vec.m_type); + m.type = s.type; + nb::handle new_type = meta_get_type(m); + return CoopVec(ad_coop_vec_cast(vec.m_index, (VarType) s.type), + vec.m_size, new_type); + }, nb::sig("def cast(arg0: CoopVec[T], arg1: typing.Type[ArrayT], /) -> CoopVec[ArrayT]"), + doc_nn_cast + ); + + m.def("fma", &coop_vec_ternary_op); + m.def("minimum", &coop_vec_binary_op); + m.def("maximum", &coop_vec_binary_op); + m.def("step", &coop_vec_binary_op, doc_step); + m.def("log2", &coop_vec_unary_op); + m.def("exp2", &coop_vec_unary_op); + m.def("tanh", &coop_vec_unary_op); + m.def("step", [](nb::handle h0, nb::handle h1) { + return select( + nb::steal(PyObject_RichCompare(h0.ptr(), h1.ptr(), Py_LT)), + nb::int_(0), nb::int_(1)); + }); + m.def("abs", coopvec_abs_workaround); +} diff --git a/src/python/coop_vec.h b/src/python/coop_vec.h new file mode 100644 index 00000000..47b4e076 --- /dev/null +++ b/src/python/coop_vec.h @@ -0,0 +1,83 @@ +/* + src/coop_vec.h -- Python bindings for Cooperative CoopVecs + + Copyright (c) 2025 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a BSD-style + license that can be found in the LICENSE file. +*/ + +#include "common.h" + +extern void export_coop_vec(nb::module_ &m); + +/// Cooperative vector container data structure +struct CoopVec { + /// JIT variable ID + uint64_t m_index = 0; + /// Number of entries + uint32_t m_size = 0; + /// Element type + nb::handle m_type; + + CoopVec(nb::handle arg); + + /// Steals ownership of 'index' + CoopVec(uint64_t index, uint32_t size, nb::handle type) + : m_index(index), m_size(size), m_type(type) { } + + /// Copy constructor + CoopVec(const CoopVec &vec) + : m_index(vec.m_index), m_size(vec.m_size), m_type(vec.m_type) { + ad_var_inc_ref(m_index); + } + CoopVec(CoopVec &&vec) noexcept + : m_index(vec.m_index), m_size(vec.m_size), m_type(vec.m_type) { + vec.m_index = 0; + vec.m_size = 0; + vec.m_type = nb::handle(); + } + CoopVec& operator=(CoopVec &&x) { + ad_var_dec_ref(m_index); + m_index = x.m_index; + m_size = x.m_size; + m_type = x.m_type; + x.m_index = x.m_size = 0; + x.m_type = nb::handle(); + return *this; + } + ~CoopVec() { ad_var_dec_ref(m_index); } + + /// Expand a cooperative vector into a Python list + nb::list expand_to_list() const; + + /// Expand a cooperative vector into a Dr.Jit array type (e.g. ArrayXf) + nb::object expand_to_vector() const; + +private: + void construct(nb::handle arg); +}; + +/// Shared view into a matrix +struct MatrixView { + /// Shape, strides, etc. + MatrixDescr descr{}; + + /// Dr.Jit 1D array holding the data + nb::object buffer; + + /// Should the view be transposed? 
+ bool transpose = false; + + MatrixView() = default; + MatrixView(const MatrixDescr &descr, const nb::handle &buffer) + : descr(descr), buffer(nb::borrow(buffer)), transpose(false) { } + + nb::str repr() const; + MatrixView getitem(nb::object arg) const; + uint64_t index() const; +}; + + +extern nb::handle view_type; +extern nb::handle coop_vector_type; diff --git a/src/python/detail.cpp b/src/python/detail.cpp index ef9f27e9..d9fc2cb2 100644 --- a/src/python/detail.cpp +++ b/src/python/detail.cpp @@ -307,7 +307,14 @@ void export_detail(nb::module_ &) { []() { int major, minor, patch; jit_llvm_version(&major, &minor, &patch); - return nb::str("{}.{}.{}").format(major, minor, patch); + return nb::make_tuple(major, minor, patch); + }) + + .def("cuda_version", + []() { + int major, minor; + jit_cuda_version(&major, &minor); + return nb::make_tuple(major, minor); }) .def("trace_func", &trace_func, "frame"_a, "event"_a, diff --git a/src/python/dlpack.cpp b/src/python/dlpack.cpp index a0ac00ed..b96f7543 100644 --- a/src/python/dlpack.cpp +++ b/src/python/dlpack.cpp @@ -252,6 +252,10 @@ void export_dlpack(nb::module_ &) { [](nb::handle_t h) { return nb::ndarray(dlpack(h, true).handle()); }, doc_array) + .def("to_numpy", // needed for Matplotlib + [](nb::handle_t h) { + return nb::ndarray(dlpack(h, true).handle()); + }, doc_array) .def("torch", [](nb::handle_t h) { nb::module_ torch = nb::module_::import_("torch.utils.dlpack"); diff --git a/src/python/docstr.rst b/src/python/docstr.rst index 0046499e..e06df9fd 100644 --- a/src/python/docstr.rst +++ b/src/python/docstr.rst @@ -579,7 +579,7 @@ Returns: object: The result of the operation ``arg*arg`` -.. topic:: pow +.. topic:: power Raise the first argument to a power specified via the second argument. @@ -591,7 +591,7 @@ reduces operation to a sequence of multiplies and adds (potentially followed by a reciprocation operation when ``arg1`` is negative). - The general case involves recursive use of the identity ``pow(arg0, arg1) = + The general case involves recursive use of the identity ``power(arg0, arg1) = exp2(log2(arg0) * arg1)``. There is no difference between using :py:func:`drjit.power()` and the builtin @@ -7311,7 +7311,7 @@ `__. The operation is a no-op when no profile collection tool is attached. - Note the difference between this context manager and :py:ref:`dr.profile_enable() + Note the difference between this context manager and :py:func:`dr.profile_enable() `, which enables targeted profiling of a smaller region of code (as opposed to profiling the entire program). @@ -7328,7 +7328,7 @@ code_to_be_profiled() Note the difference between this context manager and - :py:ref:`dr.profile_range() `, which annotates a profiled + :py:func:`dr.profile_range() `, which annotates a profiled region with a label. .. topic:: ReduceMode @@ -7503,7 +7503,7 @@ >>> from drjit.llvm.ad import TensorXf >>> value = dr.arange(TensorXf, 6) - >>> dr.reshape(dtype=TensorXf, value=value, shape=(3, -1)) + >>> dr.reshape(value, (3, -1)) [[0, 1] [2, 3] [4, 5]] @@ -7511,16 +7511,17 @@ 2. **Reshaping nested arrays**: The function can ravel and unravel nested arrays (which have some static dimensions). This provides a high-level interface that subsumes the functions :py:func:`drjit.ravel` and - :py:func:`drjit.unravel`. + :py:func:`drjit.unravel`. In this case, the target ``dtype`` must be + specified: .. 
code-block:: pycon >>> from drjit.llvm.ad import Array2f, Array3f >>> value = Array2f([1, 2, 3], [4, 5, 6]) - >>> dr.reshape(dtype=Array3f, value=value, shape=(3, -1), order='C') + >>> dr.reshape(Array3f, value, shape=(3, -1), order='C') [[1, 4, 2], [5, 3, 6]] - >>> dr.reshape(dtype=Array3f, value=value, shape=(3, -1), order='F') + >>> dr.reshape(Array3f, value, shape=(3, -1), order='F') [[1, 3, 5], [2, 4, 6]] @@ -7604,11 +7605,12 @@ f'{size} elements in a queue of size {queue_size}') # Reshape the queue and re-run the loop - state = dr.reshape(dtype=type(state), value=queue, shape=size, shrink=True) + state = dr.reshape(queue, shape=size, shrink=True) Args: dtype (type): Desired output type of the reshaped array. This could equal ``type(value)`` or refer to an entirely different array type. + Must only be specified if the target dtype is different. value (object): An arbitrary Dr.Jit array, tensor, or :ref:`PyTree `. The function returns unknown objects of other types @@ -8104,3 +8106,181 @@ .. topic:: leak_warnings Query whether leak warnings are enabled. See :py:func:`drjit.detail.set_leak_warnings()`. + +.. topic:: step + + Step function. + + This function generates a step function by comparing ``arg0`` to ``arg1``. + The function is equivalent to + + .. code-block:: python + + dr.select( + arg0 < arg1, + 0, # if arg0 < arg1 + 1, # if arg1 >= arg1 + ) + + Args: + arg0 (object): A Dr.Jit array/tensor or Python arithmetic type + + arg1 (object): A Dr.Jit array/tensor or Python arithmetic type + + Returns: + object: The computed array as described above + +.. topic:: nn_CoopVec + + A *cooperative vector* is a dynamically-sized container of elements of a + consistent type. It admits both floating point and integer 1D arrays as + elements (e.g., :py:class:`drjit.cuda.Float16`, + :py:class:`drjit.llvm.UInt32`). Cooperative vectors primarily exist to + enable the compilation of expressions that make use of matrix-vector + multiplication. + + Seen from a high level, cooperative vectors resemble nested array types, + such as as :py:class:`drjit.cuda.ArrayXf16`. A variety of conversions + between cooperative vectors and regular Dr.Jit arrays are possible. + + .. code-block:: python + + # Pack individual components into a cooperative vector + vec = drjit.nn.CoopVec(x, y, z) + + # Unpack components + x, y, z = vec + + # Unpack directly into 3D array + xyz = Array3f(vec) + + # Convert a 3D array and a 2D array into a 5D cooperative vector + a1: Array3f = ... + a2: Array2f = ... + vec = drjit.nn.CoopVec(a1, a2) + + The main difference between regular Dr.Jit arrays and cooperative vectors is + that they *do not permit indexed element access*. For example, the following + operation raises an Exception: + + .. code-block:: pycon + + >>> vec = drjit.nn.CoopVec(x, y, z) + >>> vec[1] + Traceback (most recent call last): + File "", line 1, in + TypeError: 'drjit.nn.CoopVec' object is not subscriptable + + The compilation stack may arbitrarily redistribute the elements of a + cooperative vector across threads for efficiency (this is what + *cooperative* refers to). Indexed access to a cooperative vector's elements + would interfere with such optimizations. + + To unpack a cooperative vector into its components, use an expression + like ``x, y, z = vec``, ``ArrayXf(vec)``, or ``list(vec)``. + +.. topic:: nn_CoopVec_init + + The constructor accepts a variable number of arguments including Dr.Jit + arrays, scalar Python integers and floating point values, and :ref:`PyTrees + `. 
+   `. It flattens this input into a list of vector components.
+
+   At least one Jit-compiled array must be provided as input so that Dr.Jit can
+   infer the cooperative vector's element type. An exception will be raised if
+   the input contains Dr.Jit arrays of inconsistent scalar types (e.g.,
+   :py:class:`drjit.cuda.Array2f` and :py:class:`drjit.cuda.UInt`).
+
+.. topic:: nn_MatrixView
+
+   The :py:class:`drjit.nn.MatrixView` provides a pointer into a buffer along
+   with shape and type metadata.
+
+   Dr.Jit uses views to tightly pack sequences of matrices and bias vectors
+   into a joint buffer, and to preserve information about the underlying data
+   type and layout. The :py:func:`__getitem__` function can be used to slice a
+   view into smaller sub-blocks.
+
+   The typical process is to pack a PyTree of weight and bias vectors via
+   :py:func:`drjit.nn.pack()` into an inference or training-optimal
+   representation. The returned views can then be passed to
+   :py:func:`drjit.nn.matvec()`.
+
+.. topic:: nn_view
+
+   Convert a Dr.Jit array or tensor into a *view*.
+
+   This function simply returns a view of the original tensor without
+   transforming the underlying representation. This is useful to
+
+   - Use :py:func:`drjit.nn.matvec` with a row-major matrix layout (which,
+     however, is not recommended, since this can be significantly slower
+     compared to matrices in inference/training-optimal layouts).
+
+   - Slice a larger matrix into sub-blocks before passing them to
+     :py:func:`drjit.nn.pack` (which also accepts *views* as inputs).
+     This is useful when several matrices are already packed into a single
+     matrix (which is, however, still in row-major layout). They can then be
+     directly re-packed into optimal layouts without performing further
+     unnecessary copies.
+
+.. topic:: nn_pack
+
+   A training-optimal layout must be used if the program *backpropagates*
+   (as in :py:func:`dr.backward*() `) gradients through
+   matrix-vector products. Inference (primal evaluation) and forward derivative
+   propagation (as in :py:func:`dr.forward*() `) do not
+   require a training-optimal layout.
+
+   If the input matrices are already packed in a row-major layout, call
+   :py:func:`dr.nn.view() ` to create an efficient reference
+   and then pass slices of the view to :py:func:`dr.nn.pack()
+   `. This avoids additional copies.
+
+   .. code-block:: python
+
+      mat: TensorXf = ...
+      mat_view = dr.nn.view(mat)
+
+      A1_view, A2_view = dr.nn.pack(
+          mat_view[0:32, :],
+          mat_view[32:64, :]
+      )
+
+.. topic:: nn_unpack
+
+   The function :py:func:`dr.nn.unpack() ` transforms a
+   sequence (or :ref:`PyTree `) of vectors and optimal-layout matrices
+   back into row-major layout.
+
+   .. code-block:: python
+
+      A_out, b_out = dr.nn.unpack(A_opt, b_opt)
+
+   Note that the outputs of this function are (row-major) *views* into a shared
+   buffer. Each view holds a reference to the shared buffer. Views can be
+   converted back into regular tensors:
+
+   .. code-block:: python
+
+      A = TensorXf16(A_out)
+
+.. topic:: nn_matvec
+
+   Evaluate a matrix-vector multiplication involving a cooperative vector.
+
+   This function takes a *matrix view* ``A`` (see :py:func:`drjit.nn.pack`
+   and :py:func:`drjit.nn.view` for details on views) and a *cooperative
+   vector* ``x``. It then computes the associated matrix-vector product and
+   returns it in the form of a new cooperative vector (potentially with a
+   different size).
+
+   The function can optionally apply an additive bias (i.e., to evaluate ``A@x
+   + b``). This bias vector ``b`` should also be specified as a view.
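+
+   For illustration, a minimal usage sketch (the names ``A_view``, ``b_view``,
+   and ``x`` below are placeholders for a packed matrix view, an optional bias
+   view, and a compatible cooperative vector created elsewhere):
+
+   .. code-block:: python
+
+      # 'A_view'/'b_view' come from drjit.nn.pack(), 'x' is a drjit.nn.CoopVec
+      y = dr.nn.matvec(A_view, x, b_view)   # evaluates A @ x + b
+      y0, *rest = y                         # unpack the resulting cooperative vector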
+ + Specify ``tranpose=True`` to multiply by the transpose of the matrix ``A``. + On the CUDA/OptiX backend, this feature requires that ``A`` is in inference + or training-optimal layout. + +.. topic:: nn_cast + + Cast the numeric type underlying a cooperative vector diff --git a/src/python/eval.cpp b/src/python/eval.cpp index 01faa5e0..4bd5a2ff 100644 --- a/src/python/eval.cpp +++ b/src/python/eval.cpp @@ -11,6 +11,7 @@ #include "eval.h" #include "apply.h" #include "local.h" +#include "coop_vec.h" bool schedule(nb::handle h) { bool result_ = false; @@ -31,6 +32,8 @@ bool schedule(nb::handle h) { for (uint32_t index : local.arrays()) result |= (bool) jit_var_schedule(index); } + if (h.type().is(coop_vector_type)) + nb::raise("Cooperative vectors cannot be evaluated. They must be unpacked into regular variables."); } }; @@ -65,6 +68,16 @@ static void make_opaque(nb::handle h) { ad_var_dec_ref(index_new); } + + void traverse_unknown(nb::handle h) override { + if (h.type().is(local_type)) { + Local & local = nb::cast(h); + for (uint32_t index : local.arrays()) + result |= (bool) jit_var_schedule(index); + } + if (h.type().is(coop_vector_type)) + nb::raise("Cooperative vectors cannot be evaluated. They must be unpacked into regular variables."); + } }; ScheduleForceCallback sfc; diff --git a/src/python/init.cpp b/src/python/init.cpp index bb1e1b17..be422a40 100644 --- a/src/python/init.cpp +++ b/src/python/init.cpp @@ -12,12 +12,14 @@ #include #include #include "../ext/nanobind/src/buffer.h" +#include "drjit/python.h" #include "meta.h" #include "base.h" #include "memop.h" #include "shape.h" #include "dlpack.h" #include "init.h" +#include "coop_vec.h" #include /// Forward declaration @@ -134,7 +136,7 @@ int tp_init_array(PyObject *self, PyObject *args, PyObject *kwds) noexcept { } // Try to construct from an instance created by another - // array programming framework + // array programming framework or a Dr.Jit tensor nb::object converted_complex_scalar; if (is_drjit_tensor || (!arg_is_drjit && !is_builtin(arg_tp) && nb::ndarray_check(arg))) { // For scalar types we want to rely on broadcasting below @@ -142,14 +144,31 @@ int tp_init_array(PyObject *self, PyObject *args, PyObject *kwds) noexcept { // Import flattened array in C-style ordering nb::object flattened; - if (is_drjit_tensor) - flattened = nb::steal(supp(arg_tp).tensor_array(arg)); - else - flattened = import_ndarray(s, arg); - if (s.is_complex) do_flip_axes = true; + if (is_drjit_tensor) { + const ArraySupplement &as = supp(arg_tp); + const dr::vector &shape = as.tensor_shape(inst_ptr(arg)); + if (shape.size() != s.ndim) + nb::raise("dimensionality mismatch (target has %u, " + "source has %zu dimensions)", + s.ndim, shape.size()); + for (uint32_t d = 0; d < s.ndim; ++d) { + if (s.shape[d] == DRJIT_DYNAMIC) + continue; + size_t source_shape = + do_flip_axes ? shape[shape.size() - 1 - d] + : shape[d]; + if (s.shape[d] != source_shape) + nb::raise("mismatched shape (axis %u has size %u in target type, %zu in source tensor)", + d, s.shape[d], source_shape); + } + flattened = nb::steal(as.tensor_array(arg)); + } else { + flattened = import_ndarray(s, arg); + } + nb::object unraveled = unravel( nb::borrow>(self_tp), flattened, do_flip_axes ? 
'F' : 'C'); @@ -645,6 +664,28 @@ static void ndarray_keep_alive(JitBackend backend, uint32_t index, nb::detail::n nb::object full_alt(nb::type_object dtype, nb::handle value, size_t size); nb::object empty_alt(nb::type_object dtype, size_t size); +nb::object view_to_tensor(nb::handle h, dr::vector &shape) { + MatrixView &view = nb::cast(nb::handle(h)); + if (view.transpose) + nb::raise("The view is transposed. Conversion into tensor format still " + "needs to be implemented."); + + if (view.descr.layout != MatrixLayout::RowMajor) + nb::raise("This tensor is in an inference/training-optimal layout. To " + "convert it back into tensor form, you must unpack it into a " + "row-major representation via drjit.nn.unpack()."); + + if (view.descr.stride != view.descr.cols) + nb::raise("Unsupported row stride: expected stride %u, found %u.", + view.descr.cols, view.descr.stride); + + shape.push_back(view.descr.rows); + shape.push_back(view.descr.cols); + + return view.buffer[nb::slice(view.descr.offset, + view.descr.offset + view.descr.size, 1u)]; +} + int tp_init_tensor(PyObject *self, PyObject *args, PyObject *kwds) noexcept { PyTypeObject *self_tp = Py_TYPE(self); @@ -660,7 +701,9 @@ int tp_init_tensor(PyObject *self, PyObject *args, PyObject *kwds) noexcept { bool do_flip_axes = flip_axes == Py_True; PyTypeObject *array_tp = array ? Py_TYPE(array) : nullptr; - raise_if(do_flip_axes && (shape || !array_tp || !is_drjit_type(array_tp) || + raise_if(do_flip_axes && (shape || !array_tp || + (!is_drjit_type(array_tp) && + !nb::handle(array_tp).is(coop_vector_type)) || array_tp == self_tp), "flip_axes=True requires that 'shape' is not specified, and " "that the input is a nested Dr.Jit array type (e.g. " @@ -676,6 +719,14 @@ int tp_init_tensor(PyObject *self, PyObject *args, PyObject *kwds) noexcept { // Same type -> copy constructor if (array_tp == self_tp) { + if (shape) + nb::raise( + "use 'Tensor(x.array, shape)' or 'drjit.reshape(Tensor, x, " + "shape)' to reshape a tensor"); + if (do_flip_axes) + nb::raise("The flip_axes argument is only supported when " + "constructing tensors from N-D arrays or cooperative " + "vectors"); nb::detail::nb_inst_copy(self, array); return 0; } @@ -690,6 +741,8 @@ int tp_init_tensor(PyObject *self, PyObject *args, PyObject *kwds) noexcept { // Try to construct from an instance created by another // array programming framework flat = import_ndarray(s, array, &shape_vec); + } else if (nb::isinstance(nb::handle(array))) { + flat = view_to_tensor(array, shape_vec); } else { // Infer the shape of an arbitrary data structure & flatten it VarType vt = (VarType) s.type; @@ -984,12 +1037,24 @@ nb::object linspace(const nb::type_object_t &dtype, if (size == 0) return dtype(); - nb::object result = nb::inst_alloc(counter_tp); - counter_s.init_counter((size_t) size, inst_ptr(result)); - nb::inst_mark_ready(result); + nb::object counter = nb::inst_alloc(counter_tp); + counter_s.init_counter((size_t) size, inst_ptr(counter)); + nb::inst_mark_ready(counter); + + nb::handle dtype_c = dtype; + if ((VarType) s.type == VarType::Float16) { + ArrayMeta m = s; + m.type = (uint16_t) VarType::Float32; + dtype_c = meta_get_type(m); + } double step = (stop - start) / (size - ((endpoint && size > 0) ? 1 : 0)); - return fma(dtype(result), dtype(step), dtype(start)); + nb::object result = fma(dtype_c(counter), dtype_c(step), dtype_c(start)); + + if (!dtype_c.is(dtype)) + result = dtype(result); + + return result; } /// Extract types from typing.Optional[T], typing.Union[T, None], etc. 
diff --git a/src/python/main.cpp b/src/python/main.cpp index 99712f3e..74fa3239 100644 --- a/src/python/main.cpp +++ b/src/python/main.cpp @@ -40,6 +40,7 @@ #include "tracker.h" #include "local.h" #include "resample.h" +#include "coop_vec.h" static int active_backend = -1; @@ -228,6 +229,7 @@ NB_MODULE(_drjit_ext, m_) { jit_init_async(backends); export_bind(detail); + export_coop_vec(m); export_base(m); export_init(m); export_shape(m); diff --git a/src/python/memop.cpp b/src/python/memop.cpp index 23d9c111..d0e3cc60 100644 --- a/src/python/memop.cpp +++ b/src/python/memop.cpp @@ -559,7 +559,8 @@ static void ravel_recursive(nb::handle result, nb::handle value, nb::object index = arange(nb::borrow>(index_dtype), offset, offset + strides[depth] * shape[depth], strides[depth]); - ::scatter(nb::borrow(result), nb::borrow(value), index, nb::cast(true)); + ::scatter(nb::borrow(result), nb::borrow(value), index, nb::cast(true), + ReduceMode::Permute); } else { result[offset] = value; } @@ -625,6 +626,26 @@ nb::object ravel(nb::handle h, char order, vt = (VarType) s.type; is_dynamic = s.shape[s.ndim - 1] == DRJIT_DYNAMIC; is_diff = s.is_diff; + } else if (nb::isinstance(h)) { + nb::object o = nb::borrow(h); + while (true) { + if (!nb::hasattr(o, "__len__") || nb::len(o) == 0) { + if (vt_in) + vt = (VarType) *vt_in; + break; + } + if (is_drjit_array(o)) { + const ArraySupplement &s = supp(o.type()); + backend = (JitBackend) s.backend; + vt = (VarType) s.type; + is_dynamic = s.ndim != 0 && s.shape[s.ndim - 1] == DRJIT_DYNAMIC; + is_diff = s.is_diff; + break; + } + o = o[0]; + } + } else if (nb::isinstance(h)) { + return ravel(nb::list(h), order, shape_out, strides_out, vt_in); } else if (vt_in) { vt = (VarType) *vt_in; } @@ -1065,6 +1086,18 @@ static nb::object reshape_2(nb::type_object dtype, nb::handle value, return reshape(dtype, value, shape_vec, order, shrink); } +static nb::object reshape_same_dtype(nb::handle value, + const dr::vector &target_shape, + char order, bool shrink) { + return reshape(nb::borrow(value.type()), value, target_shape, + order, shrink); +} + +static nb::object reshape_same_dtype_2(nb::handle value, Py_ssize_t shape, + char order, bool shrink) { + return reshape_2(nb::borrow(value.type()), value, shape, order, shrink); +} + static nb::object repeat_or_tile(nb::handle h, size_t count, bool tile) { struct RepeatOrTileOp : TransformCallback { size_t count; @@ -1149,6 +1182,10 @@ void export_memop(nb::module_ &m) { "shape"_a, "order"_a = 'A', "shrink"_a = false, doc_reshape) .def("reshape", &reshape_2, "dtype"_a, "value"_a, "shape"_a, "order"_a = 'A', "shrink"_a = false) + .def("reshape", &reshape_same_dtype, "value"_a, + "shape"_a, "order"_a = 'A', "shrink"_a = false, doc_reshape) + .def("reshape", &reshape_same_dtype_2, "value"_a, + "shape"_a, "order"_a = 'A', "shrink"_a = false) .def("tile", [](nb::handle h, size_t count) { return repeat_or_tile(h, count, true); diff --git a/src/python/meta.cpp b/src/python/meta.cpp index 940bafb8..ba3b962b 100644 --- a/src/python/meta.cpp +++ b/src/python/meta.cpp @@ -12,7 +12,6 @@ #include "base.h" #include "../ext/nanobind/src/buffer.h" #include -#include /// Check if the given metadata record is valid bool meta_check(ArrayMeta m) noexcept { diff --git a/src/python/random.h b/src/python/random.h index 0480f6ca..43d7799e 100644 --- a/src/python/random.h +++ b/src/python/random.h @@ -54,7 +54,7 @@ void bind_pcg32(nb::module_ &m) { } if (!key) - nb::raise_type_error("Invalid 'dtype'"); + nb::raise_type_error("PCG32.next_float(): invalid 
'dtype'"); auto &&fn = self.attr(key); return !mask.is(Py_True) ? fn(mask) : fn(); diff --git a/src/python/reduce.cpp b/src/python/reduce.cpp index 17ed4dbb..9292d5cf 100644 --- a/src/python/reduce.cpp +++ b/src/python/reduce.cpp @@ -17,6 +17,7 @@ #include "init.h" #include "apply.h" #include "detail.h" +#include "coop_vec.h" #include using ReduceInit = nb::object(); @@ -542,6 +543,10 @@ nb::object dot(nb::handle h0, nb::handle h1) { } if (use_fma) { + if (tp0.is(coop_vector_type) || tp1.is(coop_vector_type)) { + nb::list o0 = nb::list(h0), o1 = nb::list(h1); + return dot(o1, o1); + } nb::object result = h0[0] * h1[0], fma = array_module.attr("fma"); for (size_t i = 1; i < lr; ++i) diff --git a/src/python/tracker.cpp b/src/python/tracker.cpp index ca68abf5..f974cb0c 100644 --- a/src/python/tracker.cpp +++ b/src/python/tracker.cpp @@ -17,6 +17,7 @@ #include "base.h" #include "local.h" #include "shape.h" +#include "coop_vec.h" #include #include #include @@ -332,6 +333,8 @@ bool VariableTracker::Impl::traverse(Context &ctx, nb::handle h) { ctx.label.c_str(), nb::inst_name(prev).c_str(), nb::type_name(tp).c_str()); + // Were there any external changes to sub-PyTree variable indices (as + // opposed to changes done by the VariableTracker) bool changed = false; if (is_drjit_type(tp)) { @@ -402,8 +405,7 @@ bool VariableTracker::Impl::traverse(Context &ctx, nb::handle h) { VarInfo vi = jit_set_backend((uint32_t) idx); if (new_variable) { - if (!v->index_orig) - v->index_orig = ad_var_inc_ref(idx); + v->index_orig = ad_var_inc_ref(idx); v->index = ad_var_inc_ref(idx); v->size = vi.size; } else { @@ -521,6 +523,41 @@ bool VariableTracker::Impl::traverse(Context &ctx, nb::handle h) { ScopedAppendLabel guard(ctx, "[", nb::repr(kv[0]).c_str(), "]"); changed |= traverse(ctx, kv[1]); } + } else if (tp.is(coop_vector_type)) { + CoopVec *vec = nb::cast(h, false); + uint32_t idx = vec->m_index; + size_t size = size_valid(v, ctx.label, h, vec->m_size); + + if (new_variable) { + v->index_orig = ad_var_inc_ref(idx); + v->index = ad_var_inc_ref(idx); + } else { + changed = idx != v->index; + if (changed) { + uint64_t old = v->index; + v->index = ad_var_inc_ref(idx); + ad_var_dec_ref(old); + } + } + + if (!ctx.write && !changed && !new_variable) { + for (size_t i = 0; i < size; ++i) { + ScopedAppendLabel guard(ctx, "[", i, "]"); + changed |= traverse(ctx, state.find(ctx.label)->second.value); + } + } else { + nb::list l(h), r; + for (size_t i = 0; i < size; ++i) { + ScopedAppendLabel guard(ctx, "[", i, "]"); + changed |= traverse(ctx, l[i]); + } + if (ctx.write) { + *vec = CoopVec(l); + ad_var_inc_ref(vec->m_index); + ad_var_dec_ref(v->index); + v->index = vec->m_index; + } + } } else { nb::object traverse_cb = nb::getattr( h, ctx.write ? 
DR_STR(_traverse_1_cb_rw) : DR_STR(_traverse_1_cb_ro), @@ -631,7 +668,7 @@ void VariableTracker::verify_size(size_t size) { strcmp(jit_var_kind_name((uint32_t) v.index), "loop_phi") == 0) continue; - size_t size_2 = jit_var_size((uint32_t)v.index); + size_t size_2 = jit_var_size((uint32_t) v.index); if (size != size_2 && size != 1 && size_2 != 1 && !jit_var_is_dirty((uint32_t)v.index)) nb::raise("this operation processes arrays of size %zu, while " @@ -730,6 +767,11 @@ nb::object VariableTracker::Impl::restore(dr::string &label) { ScopedAppendLabel guard(label, "[", nb::repr(k).c_str(), "]"); d[k] = restore(label); } + } else if (tp.is(coop_vector_type)) { + CoopVec *vec = nb::cast(value, false); + ad_var_inc_ref(v->index_orig); + ad_var_dec_ref(vec->m_index); + vec->m_index = v->index_orig; } else { if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { for (auto [k, _] : ds) { @@ -847,48 +889,58 @@ std::pair VariableTracker::Impl::rebuild(dr::string &label) { value = tmp; } } - } else { - if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { - nb::object tmp = tp(); - for (auto [k, _] : ds) { - ScopedAppendLabel guard(label, ".", nb::str(k).c_str()); - auto [o, n] = rebuild(label); - nb::setattr(tmp, k, o); - new_object |= n; - } - if (new_object) { - if (mutate) { - for (nb::handle k : ds.keys()) - nb::setattr(value, k, nb::getattr(tmp, k)); - new_object = false; - } else { - value = tmp; - } - } - } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { - nb::dict tmp; - for (auto field : df) { - nb::object k = field.attr(DR_STR(name)); - ScopedAppendLabel guard(label, ".", nb::str(k).c_str()); - auto [o, n] = rebuild(label); - tmp[k] = o; - new_object |= n; + } else if (tp.is(coop_vector_type)) { + size_t size = size_valid(v, label, value, nb::len(value)); + nb::list tmp; + + for (size_t i = 0; i < size; ++i) { + ScopedAppendLabel guard(label, "[", i, "]"); + auto [o, n] = rebuild(label); + tmp.append(o); + } + + value = nb::cast(CoopVec(tmp)); + new_object = true; + } else if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { + nb::object tmp = tp(); + for (auto [k, _] : ds) { + ScopedAppendLabel guard(label, ".", nb::str(k).c_str()); + auto [o, n] = rebuild(label); + nb::setattr(tmp, k, o); + new_object |= n; + } + if (new_object) { + if (mutate) { + for (nb::handle k : ds.keys()) + nb::setattr(value, k, nb::getattr(tmp, k)); + new_object = false; + } else { + value = tmp; } - if (new_object) { - if (mutate) { - for (auto field : df) { - nb::object k = field.attr(DR_STR(name)); - nb::setattr(value, k, tmp[k]); - } - new_object = false; - } else { - value = tp(**tmp); + } + } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { + nb::dict tmp; + for (auto field : df) { + nb::object k = field.attr(DR_STR(name)); + ScopedAppendLabel guard(label, ".", nb::str(k).c_str()); + auto [o, n] = rebuild(label); + tmp[k] = o; + new_object |= n; + } + if (new_object) { + if (mutate) { + for (auto field : df) { + nb::object k = field.attr(DR_STR(name)); + nb::setattr(value, k, tmp[k]); } + new_object = false; + } else { + value = tp(**tmp); } - } else if (!value.is(v->value)) { - value = v->value; - new_object = true; } + } else if (!value.is(v->value)) { + value = v->value; + new_object = true; } return { value, new_object }; diff --git a/tests/test_coop_vec.py b/tests/test_coop_vec.py new file mode 100644 index 00000000..395ba041 --- /dev/null +++ b/tests/test_coop_vec.py @@ -0,0 +1,640 @@ +import drjit as dr +import drjit.nn as nn +import pytest +import sys + 
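+# Cooperative vectors need a sufficiently recent CUDA driver on the GPU
+# backend; the helper below skips such tests when the reported CUDA version
+# is too old.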
+def skip_if_coopvec_not_supported(t): + if dr.backend_v(t) == dr.JitBackend.CUDA: + if dr.detail.cuda_version() < (12, 8): + pytest.skip("CUDA driver does not support cooperative vectors (Driver R570) or later is required") + +@pytest.test_arrays('jit,float16,shape=(3, *),-diff', 'jit,float32,shape=(3, *),-diff') +def test01_pack_unpack(t): + skip_if_coopvec_not_supported(t) + + # Test coop vector creation and unpacking + m = sys.modules[t.__module__] + v = dr.full(dr.value_t(t), 7, 32) + x = nn.CoopVec(t(1, 2, 3), t(4, 5, 6), v, 8) + assert len(x) == 8 + assert len(nn.CoopVec(*x, 2, (4, 5), *x)) == 19 + y = list(x) + z = m.ArrayXf(x) + assert len(y) == 8 and len(z) == 8 + result_ok = True + for i in range(8): + result_ok &= dr.all(y[i] == i+1) + result_ok &= dr.all(z[i] == i+1) + assert result_ok + + +@pytest.mark.parametrize('size', [0, 20, 10]) +@pytest.test_arrays('jit,float16,shape=(*),-diff', 'jit,float32,shape=(*),-diff') +def test02_add_sub(t, size): + skip_if_coopvec_not_supported(t) + + # Test addition and subtraction + x = nn.CoopVec(dr.full(t, 5, 32), 6, *tuple(range(size))) + y = x + 15 + z = y - 2 + r0, r1 = list(z)[0:2] + dr.schedule(r0, r1) + assert dr.all((r0 == 18) & (r1 == 19)) + +@pytest.mark.parametrize('size', [0, 20, 100]) +@pytest.test_arrays('jit,float16,shape=(*),-diff', 'jit,float32,shape=(*),-diff') +def test03_add_min_max_fma(t, size): + skip_if_coopvec_not_supported(t) + + # Test min/max/FMA operations + x = nn.CoopVec(t(5), 8, *tuple(range(size))) + x_min = dr.minimum(x, 6) + x_max = dr.maximum(x, 7) + # zero addition needed to work around a constant propagation bug in R570 driver.. + zero = dr.opaque(t, 0) + z = dr.fma(x_min, x_max, 1+zero) + r0, r1 = list(z)[0:2] + dr.schedule(r0, r1) + assert r0 == 36 and r1 == 49 + + +@pytest.mark.parametrize('sub_slice', [False, True]) +@pytest.test_arrays('jit,float16,shape=(*),-diff') +def test04_pack_unpack(t, sub_slice): + skip_if_coopvec_not_supported(t) + + # Test the nn.pack() and nn.unpack() memory operations + m = sys.modules[t.__module__] + extra = 2 if sub_slice else 0 + X = m.TensorXf16(dr.arange(t, 24*(32+extra)), (24, 32+extra)) + Xv = nn.view(X) + + assert Xv.dtype == dr.VarType.Float16 + assert Xv.offset == 0 + assert Xv.size == 24*(32+extra) + assert Xv.shape == (24, 32+extra) + assert Xv.stride == 32+extra + assert Xv.buffer is X.array + + Xv1 = Xv[0:16, 0:32] + Xv2 = Xv[16:, 0:32] + X1 = X[0:16, 0:32] + X2 = X[16:, 0:32] + + assert Xv1.dtype == dr.VarType.Float16 + assert Xv1.offset == 0 + assert Xv1.shape == (16, 32) + assert Xv1.stride == 32+extra + assert Xv1.size == (Xv1.shape[0] - 1) * Xv1.stride + Xv1.shape[1] + assert Xv1.buffer is X.array + + assert Xv2.dtype == dr.VarType.Float16 + assert Xv2.offset == 16*(32+extra) + assert Xv2.size == (Xv2.shape[0] - 1) * Xv2.stride + Xv2.shape[1] + assert Xv2.shape == (8, 32) + assert Xv2.stride == 32+extra + assert Xv2.buffer is X.array + + for i in range(2): + _, *Pa = nn.pack( + Xv1, Xv2, + layout='inference' if i == 0 else 'training' + ) + + _, X1a, X2a = nn.unpack(*Pa) + assert dr.all(m.TensorXf16(X1a) == X1[:, 0:32], axis=None) + assert dr.all(m.TensorXf16(X2a) == X2[:, 0:32], axis=None) + + +@pytest.mark.parametrize('shape', [(2, 8), (5, 2), (16, 16)]) +@pytest.mark.parametrize('transpose', [False, True]) +@pytest.mark.parametrize('bias', [False, True]) +@pytest.mark.parametrize('pack', [False, True]) +@pytest.test_arrays('jit,tensor,float16,-diff', 'jit,tensor,float32,-diff') +def test05_matvec(t, shape, transpose, bias, pack): + 
skip_if_coopvec_not_supported(t) + + # Test matrix multiplication for various sizes and configurations (primal) + m = sys.modules[t.__module__] + Tensor = t + Float = dr.array_t(t) + + if dr.backend_v(t) == dr.JitBackend.CUDA: + if (not pack and shape[1] == 2) or \ + (not pack and transpose) or \ + dr.type_v(t) == dr.VarType.Float32: + pytest.skip("Unsupported configuration") + + output_size = shape[1] if transpose else shape[0] + input_size = shape[0] if transpose else shape[1] + + A = Tensor(m.PCG32(dr.prod(shape), 1).next_float_normal(Float), shape) + A_n = A.numpy() + + if bias: + b = Tensor(m.PCG32(output_size, 2).next_float_normal(Float)) + b_n = b.numpy() + else: + b = b_n = None + + if pack: + if bias: + _, A, b = nn.pack(A, b) + assert A.buffer is b.buffer + else: + _, A = nn.pack(A) + else: + A = nn.view(A) + if bias: + b = nn.view(b) + + rng_3 = m.PCG32(32, 3) + x = [rng_3.next_float_normal(Float) for _ in range(input_size)] + x_n = Tensor(x).numpy() + + x = nn.CoopVec(x) + r = nn.matvec(A, x, b, transpose=transpose) + r_n = Tensor(r).numpy() + + if transpose: + A_n = A_n.T + ref = A_n @ x_n + + if bias: + ref += b_n[:, None] + + assert dr.allclose(r_n, ref) + + +@pytest.test_arrays('jit,shape=(*),float16,-diff', 'jit,shape=(*),float32,-diff') +@pytest.mark.parametrize('op', ['exp2', 'log2', 'tanh']) +def test06_unary(t, op): + skip_if_coopvec_not_supported(t) + + # Test some special unary operations that are supported by coop vectors + func = getattr(dr, op) + x = nn.CoopVec(t(0.1), t(0.2), t(0.3)) + r = func(x) + x, y, z = r + dr.schedule(x, y, z) + assert dr.allclose(x[0], func(0.1), rtol=1e-3) + assert dr.allclose(y[0], func(0.2), rtol=1e-3) + assert dr.allclose(z[0], func(0.3), rtol=1e-3) + + +@pytest.test_arrays('jit,shape=(*),float16,-diff', 'jit,shape=(*),float32,-diff') +def test07_step(t): + skip_if_coopvec_not_supported(t) + + # Test the dr.step() function on coop vectors + x = nn.CoopVec(t(0.1), t(0.2)) + y = nn.CoopVec(t(0.15), t(0.15)) + z = dr.step(x, y) + r0, r1 = z + dr.schedule(r0, r1) + assert r0 == 0 and r1 == 1 + + +@pytest.test_arrays('jit,shape=(*),float16,diff', 'jit,shape=(*),float32,diff') +def test08_fwd_grad_unpack(t): + skip_if_coopvec_not_supported(t) + + # Test that forward gradients correctly propagate through coop vector creation and unpacking + a, b = t(1), t(2) + dr.enable_grad(a, b) + z = nn.CoopVec(a, b) # pack + assert dr.grad_enabled(z) + assert not dr.grad_enabled(dr.detach(z)) + x, y = z # unpack + a.grad = 4 + b.grad = 5 + dr.forward_to(x, y) + dr.schedule(x.grad, y.grad) + assert x.grad == 4 + assert y.grad == 5 + assert dr.grad_enabled(z) + dr.disable_grad(z) + assert not dr.grad_enabled(z) + + +@pytest.test_arrays('jit,shape=(*),float16,diff', 'jit,shape=(*),float32,diff') +def test09_bwd_grad_unpack(t): + skip_if_coopvec_not_supported(t) + + # Test that backward gradients correctly propagate through coop vector creation and unpacking + a, b = t(1), t(2) + dr.enable_grad(a, b) + z = nn.CoopVec(a, b) # pack + x, y = z # unpack + x.grad = 4 + y.grad = 5 + dr.backward_to(a, b) + dr.schedule(a.grad, b.grad) + assert a.grad == 4 + assert b.grad == 5 + + +@pytest.test_arrays('jit,shape=(*),float16,diff', 'jit,shape=(*),float32,diff') +def test10_fwd_addition(t): + skip_if_coopvec_not_supported(t) + + # Propagate forward gradients through an addition + a, b = t(1), t(1) + c, d = t(1), t(1) + dr.enable_grad(a, b, c, d) + x0 = nn.CoopVec(a, b) + x1 = nn.CoopVec(c, d) + x2 = x0 + x1 + r0, r1 = x2 + a.grad = 1 + b.grad = 2 + c.grad = 100 + 
d.grad = 200 + dr.forward_to(r0, r1) + dr.schedule(r0.grad, r1.grad) + assert r0.grad == 101 and r1.grad == 202 + + +@pytest.test_arrays('jit,shape=(*),float16,diff', 'jit,shape=(*),float32,diff') +def test11_bwd_mul(t): + skip_if_coopvec_not_supported(t) + + # Propagate forward gradients through a multiplication + a, b = t(8), t(9) + c, d = t(3), t(2) + dr.enable_grad(a, b, c, d) + x0 = nn.CoopVec(a, b) + x1 = nn.CoopVec(c, d) + x2 = x0 * x1 + r0, r1 = x2 + r0.grad = 1 + r1.grad = 10 + dr.backward_to(a, b, c, d) + dr.schedule(a.grad, b.grad, c.grad, d.grad) + assert a.grad == 3 and b.grad == 20 + assert c.grad == 8 and d.grad == 90 + + +@pytest.test_arrays('jit,shape=(*),float16,diff', 'jit,shape=(*),float32,diff') +def test12_bwd_min_max_fma(t): + skip_if_coopvec_not_supported(t) + + # Check derivatives of supported binary/ternary operations + x = [ t(1), t(2), t(3), t(4) ] + y = t(5) + z = t(6) + minval = t(25) + maxval = t(12) + dr.enable_grad(x, y, z, minval, maxval) + q = nn.CoopVec(x) + + q = dr.fma(q, y, z) + q = dr.minimum(q, minval) + q = dr.maximum(q, maxval) + + a, b, c, d = q + dr.backward_from(a+b*2 + c*4 + d*8) + dr.schedule(x[0].grad, x[1].grad, x[2].grad, x[3].grad, y.grad, + z.grad, minval.grad, maxval.grad, a, b, c, d) + assert a[0] == 12 and b[0] == 16 and c[0] == 21 and d[0] == 25 + assert x[0].grad[0] == 0 and x[1].grad[0] == 10 and x[2].grad[0] == 20 and x[3].grad[0] == 0 + assert minval.grad[0] == 8 and maxval.grad[0] == 1 + +@pytest.test_arrays('jit,shape=(*),float16,diff', 'jit,shape=(*),float32,diff') +def test13_exp2_tanh_fwd(t): + skip_if_coopvec_not_supported(t) + + # Check derivatives of supported unary transcendental operations + x = t(2) + dr.enable_grad(x) + y = nn.CoopVec(x) + r0 = dr.exp2(y) + r1 = dr.tanh(y) + r0, = r0; r1, = r1 + dr.forward_from(x) + dr.schedule(r0, r1, r0.grad, r1.grad) + assert dr.allclose(r0[0], 4) + assert dr.allclose(r1[0], 0.9640275800758169, rtol=1e-3) + assert dr.allclose(r0.grad[0], 2.77259, rtol=1e-3) + assert dr.allclose(r1.grad[0], 0.0706508, rtol=1e-2) + + +@pytest.mark.parametrize('transpose', [False, True]) +@pytest.mark.parametrize('has_A_grad', [False, True]) +@pytest.mark.parametrize('has_x_grad', [False, True]) +@pytest.mark.parametrize('has_b_grad', [None, False, True]) +@pytest.mark.parametrize('layout', ['training', 'inference']) +@pytest.test_arrays('jit,tensor,float16,diff') +def test14_matvec_fwd(t, transpose, has_A_grad, has_x_grad, has_b_grad, layout): + skip_if_coopvec_not_supported(t) + + # Test forward-propagation of derivatives from input through matrix multiplication + m = sys.modules[t.__module__] + Tensor = t + Float = dr.array_t(t) + Matrix2f = m.Matrix2f16 + Array2f = m.Array2f16 + + if not has_A_grad and not has_x_grad and not has_b_grad: + pytest.skip("Trivial configuration") + if dr.backend_v(Float) == dr.JitBackend.LLVM and layout == 'training': + pytest.skip("Layout not used in LLVM backend") + + # Set up 'A' matrix + A = [[4, 2], [5, 1]] + A_grad = [[2, 1], [1, -1]] + _, A_v = nn.pack(Tensor(A), layout=layout) + A_ref = Matrix2f(A) + if has_A_grad: + _, A_grad_v = nn.pack(Tensor(A_grad)) + assert not dr.grad_enabled(A_v) + dr.enable_grad(A_v) + assert dr.grad_enabled(A_v) + assert not dr.grad_enabled(dr.detach(A_v)) + A_v.buffer.grad = A_grad_v.buffer + dr.enable_grad(A_ref) + dr.set_grad(A_ref, A_grad) + + # Set up 'x' vector + x = Array2f(1, 2) + if has_x_grad: + dr.enable_grad(x) + x.grad = [2, 1] + x_v = nn.CoopVec(x) + + # Set up 'b' vector + b_v = None + b_ref = Array2f(0) + if has_b_grad 
is not None: + b1, b2 = Float(-1), Float(1) + b_ref = Array2f(b1, b2) + _, b_v = nn.pack(Tensor([-1, 1])) + + if has_b_grad is True: + dr.enable_grad(b_ref) + b_ref.grad = [1, -1] + _, b_grad_v = nn.pack(Tensor([1, -1])) + dr.enable_grad(b_v.buffer) + b_v.buffer.grad = b_grad_v.buffer + + # Compute the reference + if transpose: + A_ref = A_ref.T + y_ref = A_ref @ x + b_ref + + y = Array2f(nn.matvec(A_v, x_v, b_v, transpose)) + dr.forward_to(y, y_ref) + dr.schedule(y, y.grad, y_ref, y_ref.grad) + + # print(f"primal: y={y} vs ref={y_ref}") + # print(f"grad: y={y.grad} vs ref={y_ref.grad}") + + assert dr.all((y == y_ref) & (y.grad == y_ref.grad)) + + +@pytest.mark.parametrize('transpose', [False, True]) +@pytest.test_arrays('jit,tensor,float16,-diff') +def test15_matvec_in_vcall(t, transpose): + skip_if_coopvec_not_supported(t) + + # Check that mat-vec products still work as expected when done from a callable + Float = dr.array_t(t) + UInt32 = dr.uint32_array_t(Float) + size = 64 + A = dr.normal(t, (size, size)) + b = dr.normal(t, size) + _, A, b = nn.pack(A, b) + + def mult_it(): + x = nn.CoopVec( + Float(i/(size-1) - 0.5) for i in range(size) + ) + return list(nn.matvec(A, x, b, transpose=transpose))[0] + + r0 = mult_it() + r1 = dr.switch(UInt32(0), [mult_it]) + + dr.schedule(r0, r1) + assert dr.allclose(r0[0], r1[0]) + + # Try again without bias vector + b = None + + r0 = mult_it() + r1 = dr.switch(UInt32(0), [mult_it]) + + dr.schedule(r0, r1) + assert r0[0] == r1[0] + + +@pytest.mark.parametrize('in_vcall', [False, True]) +@pytest.test_arrays('jit,tensor,float16,diff') +def test16_matvec_bwd(t, in_vcall): + skip_if_coopvec_not_supported(t) + + # Test the reverse-mode derivative of a matrix-vector product + # (potentially in a vcall) + + m = sys.modules[t.__module__] + UInt32 = m.UInt32 + A = t([[1, 3], [-2, 4], [3, -2]]) + b = t([0, 0, 0]) + buffer, Av, bv = nn.pack(A, b, layout='training') + x = m.Array2f16(2, 4) + dr.enable_grad(x, buffer) + + def do_mul(x): + xv = nn.CoopVec(x) + yv = nn.matvec(Av, xv, bv) + return m.Array3f16(yv) + + if in_vcall: + y = dr.switch(UInt32(0), [do_mul], x) + else: + y = do_mul(x) + + z = dr.opaque(dr.array_t(t), 0) + + y.grad = (-2+z, 5+z, 10+z) + dr.backward_from(y) + grad_x = x.grad + + # print(f"{y=}") + # print(f"{grad_x=}") + + grad_x_ref = m.Array2f16(18, -6) + assert dr.all(grad_x_ref == grad_x) + + dr.schedule(grad_x) + _, grad_A = nn.unpack(Av.grad) + _, grad_b = nn.unpack(bv.grad) + + grad_A = t(grad_A) + grad_b = t(grad_b)[:, 0] + grad_A_ref = t([[-4, -8], [10, 20], [20, 40]]) + grad_b_ref = t([-2, 5, 10]) + assert dr.all(grad_A_ref == grad_A) + assert dr.all(grad_b_ref == grad_b) + + +@pytest.test_arrays('jit,shape=(*),float16,diff') +def test17_cast(t): + skip_if_coopvec_not_supported(t) + + z = dr.opaque(t, 0) + a = nn.CoopVec( + z + 1, + z + 2, + z + 3 + ) + b = nn.cast(a, dr.float32_array_t(t)) + c = nn.cast(b, dr.float16_array_t(t)) + x, y, z = c + dr.eval(x, y, z) + assert x[0] == 1 and y[0] == 2 and z[0] == 3 + + + +@pytest.test_arrays('jit,shape=(*),float32,-diff') +@dr.syntax +def test18_symbolic_loop_if_stmt(t): + skip_if_coopvec_not_supported(t) + + # Test that cooperative vectors can be passed through + # symbolic loops and conditionals + UInt32 = dr.uint32_array_t(t) + a = nn.CoopVec(t(1), t(2)) + i = UInt32(0) + + while i < 10: + if i > 5: + a += a + i += 1 + + x, y = a + dr.schedule(x, y, i) + assert x[0] == 16 and y[0] == 32 + + +@pytest.test_arrays('jit,shape=(*),float32,-diff') +@dr.syntax +def test19_no_eval(t): + 
skip_if_coopvec_not_supported(t) + + # Cooperative vectors cannot be evaluted via dr.eval() + UInt32 = dr.uint32_array_t(t) + a = nn.CoopVec(t(1), t(2)) + with pytest.raises(RuntimeError, match="Cooperative vectors cannot be evaluated"): + dr.schedule(a) + with pytest.raises(RuntimeError, match="Cooperative vectors cannot be evaluated"): + dr.eval(a) + with pytest.raises(RuntimeError, match="Cooperative vectors cannot be evaluated"): + dr.make_opaque(a) + + +@pytest.mark.parametrize('mode', ['evaluated', 'symbolic']) +@pytest.test_arrays('jit,shape=(*),float16,diff') +@dr.syntax +def test20_matvec_in_loop(t, mode): + # Check that derivative inference works when + # cooperative vectors are used inside loops + skip_if_coopvec_not_supported(t) + + m = sys.modules[t.__module__] + Float16 = t + TensorXf16 = m.TensorXf16 + Float32 = m.Float32 + UInt32 = m.UInt32 + + A = dr.ones(TensorXf16, shape=(2, 2)) + b = dr.zeros(Float16, shape=(2)) + + _, A_view, b_view = nn.pack(A, b, layout='inference') + + cnt = UInt32(0) + res = Float32(0) + + while dr.hint(cnt < 3, mode=mode): + x = nn.CoopVec(Float16([0.5]), Float16([0.5])) + a, b = nn.matvec(A_view, x, b_view) + res += Float32(a) + cnt += 1 + + assert res == 3 + + +@pytest.mark.parametrize('mode', ['evaluated', 'symbolic']) +@pytest.test_arrays('jit,shape=(*),float16,diff') +@dr.syntax +def test21_optimize_in_loop_bwd(t, mode): + # Check that derivative backpropagation occurs when + # cooperative vectors are used inside loops + skip_if_coopvec_not_supported(t) + + m = sys.modules[t.__module__] + Float16 = t + TensorXf16 = m.TensorXf16 + Float32 = m.Float32 + UInt32 = m.UInt32 + + A = dr.ones(TensorXf16, shape=(2, 2)) + b = dr.zeros(Float16, shape=(2)) + + buf, A_view, b_view = nn.pack(A, b, layout='training') + dr.enable_grad(buf) + + cnt = dr.zeros(UInt32, 2) + res = dr.zeros(Float32, 2) + + while dr.hint(cnt < 3, max_iterations=-1, mode=mode): + x = nn.CoopVec(Float16(0.5), Float16(0.5)) + a, _ = nn.matvec(A_view, x, b_view) + res += Float32(a) + cnt += 1 + + dr.backward(res) + + _, A_view, b_view = nn.unpack(A_view.grad, b_view.grad) + A = TensorXf16(A_view) + b = TensorXf16(b_view) + assert dr.all(A == TensorXf16([[3, 3], [0, 0]])) + assert dr.all(b == TensorXf16([[6], [0]])) + + +@pytest.mark.parametrize('mode', ['evaluated', 'symbolic']) +@pytest.test_arrays('jit,shape=(*),float16,diff') +@dr.syntax +def test22_optimize_in_loop_bwd_v2(t, mode): + # Check that derivative backpropagation occurs when + # cooperative vectors are used inside loops, and the + # backprop call is placed there as well + + skip_if_coopvec_not_supported(t) + + m = sys.modules[t.__module__] + Float16 = t + TensorXf16 = m.TensorXf16 + Float32 = m.Float32 + UInt32 = m.UInt32 + + A = dr.ones(TensorXf16, shape=(2, 2)) + b = dr.zeros(Float16, shape=(2)) + + buf, A_view, b_view = nn.pack(A, b, layout='training') + dr.enable_grad(buf) + + cnt = dr.zeros(UInt32, 2) + res = dr.zeros(Float32, 2) + + while dr.hint(cnt < 3, mode=mode, exclude=[A_view, b_view]): + x = nn.CoopVec(Float16(0.5), Float16(0.5)) + a, _ = nn.matvec(A_view, x, b_view) + res = Float32(a) + dr.backward(res) + cnt += 1 + + _, A_view, b_view = nn.unpack(A_view.grad, b_view.grad) + A = TensorXf16(A_view) + b = TensorXf16(b_view) + assert dr.all(A == TensorXf16([[3, 3], [0, 0]])) + assert dr.all(b == TensorXf16([[6], [0]])) diff --git a/tests/test_while_loop.py b/tests/test_while_loop.py index f6fc805a..342ada06 100644 --- a/tests/test_while_loop.py +++ b/tests/test_while_loop.py @@ -736,3 +736,4 @@ def 
test31_tensor_loop_preserve_shape(t, mode): assert a.shape == (10, 11) assert a.shape == (10, 11) + diff --git a/tests/test_while_loop_ad.py b/tests/test_while_loop_ad.py index bf12f339..3d0ea36e 100644 --- a/tests/test_while_loop_ad.py +++ b/tests/test_while_loop_ad.py @@ -64,9 +64,10 @@ def test03_sum_loop_fwd(t, mode): @pytest.mark.parametrize('mode', ['evaluated', 'symbolic']) +@pytest.mark.parametrize('make_copy', [True, False]) @pytest.test_arrays('float32,diff,shape=(*)') @dr.syntax -def test04_sum_loop_rev(t, mode): +def test04_sum_loop_rev(t, mode, make_copy): # Test the "sum loop" optimization (max_iterations=-1) for # consistency against test03 UInt32 = dr.uint32_array_t(t) @@ -74,8 +75,11 @@ def test04_sum_loop_rev(t, mode): y, i = Float(0), UInt32(0) x = dr.linspace(Float, .25, 1, 4) - xo = x dr.enable_grad(x) + if make_copy: + xo = Float(x) + else: + xo = x while dr.hint(i < 10, max_iterations=-1, mode=mode): y += x**i @@ -87,6 +91,7 @@ def test04_sum_loop_rev(t, mode): assert dr.allclose(y, [1.33333, 1.99805, 3.77475, 10]) assert dr.allclose(xo.grad, [1.77773, 3.95703, 12.0956, 45]) + @pytest.mark.parametrize('variant', ['fwd', 'bwd']) @pytest.test_arrays('float32,is_diff,shape=(*)') def test05_evaluated_ad_kernel_launch_count(t, variant): @@ -132,6 +137,7 @@ def test05_evaluated_ad_kernel_launch_count(t, variant): for k in h: assert k['operation_count'] < iterations + @pytest.mark.parametrize('variant', [0, 1]) @pytest.mark.parametrize('mode', ['evaluated', 'symbolic']) @pytest.test_arrays('float32,diff,shape=(*)') @@ -240,3 +246,20 @@ def loop(l: list, t, mode): dr.backward(loss) + +@pytest.mark.parametrize('mode', ['evaluated', 'symbolic']) +@pytest.test_arrays('float32,is_diff,shape=(*)') +@dr.syntax +def test32_simple_loop(t, mode): + # Testcase for simple backwards derivatives with gathers + i = dr.uint32_array_t(t)(0) + x = dr.ones(t, 10) + q = dr.zeros(t) + dr.enable_grad(x, 10) + + while dr.hint(i < 10, max_iterations=-1, mode=mode): + q += dr.gather(t, x, i) + i += 1 + + dr.backward(q) + assert dr.all(x.grad == [1]*10)