Frozen Functions #336

Open · wants to merge 3 commits into master

drjit/__init__.py (252 additions, 2 deletions)

@@ -19,12 +19,12 @@
import sys as _sys
if _sys.version_info < (3, 11):
    try:
-        from typing_extensions import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable
+        from typing_extensions import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable, TypeVar
    except ImportError:
        raise RuntimeError(
            "Dr.Jit requires the 'typing_extensions' package on Python <3.11")
else:
-    from typing import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable
+    from typing import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable, TypeVar

from .ast import syntax, hint
from .interop import wrap
@@ -2494,6 +2494,256 @@ def binary_search(start, end, pred):

    return start

# Represents the frozen function passed to the decorator without arguments
F = TypeVar("F")
# Represents the frozen function passed to the decorator with arguments
F2 = TypeVar("F2")

@overload
def freeze(
    # Review comment (Member): Did you try rendering the documentation? (With
    # `python -m sphinx docs html` and then open `html/index.html`.) I think
    # that you will need to reference any additions in docs/reference.rst with
    # some auto* command.

    f: None = None,
    *,
    state_fn: Optional[Callable],
    limit: Optional[int] = None,
    warn_after: int = 10,
    backend: Optional[JitBackend] = None,
) -> Callable[[F], F]:
    """
Decorator to "freeze" functions, which improves efficiency by removing
repeated JIT tracing overheads.

In general, Dr.Jit traces computation and then compiles and launches kernels
containing this trace (see the section on :ref:`evaluation <eval>` for
details). While the compilation step can often be skipped via caching, the
tracing cost can still be significant, especially when repeatedly evaluating
complex models, e.g., as part of an optimization loop.

The :py:func:`@dr.freeze <drjit.freeze>` decorator addresses this problem by
altogether removing the need to trace repeatedly. For example, consider the
following decorated function:

.. code-block:: python

    @dr.freeze
    def f(x, y, z):
        return ...  # Complicated code involving the arguments

Dr.Jit will trace the first call to the decorated function ``f()``, while
collecting additional information about the nature of the function's inputs
and about the CPU/GPU kernel launches that represent the body of ``f()``.

If the function is subsequently called with *compatible* arguments (more on
this below), it will immediately launch the previously made CPU/GPU kernels
without re-tracing, which can substantially improve performance.
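
For instance (a minimal sketch; the ``Float`` array type is assumed to come
from a Dr.Jit backend module such as ``drjit.cuda``):

.. code-block:: python

    import drjit as dr
    from drjit.cuda import Float

    @dr.freeze
    def g(x, y):
        return x * 2 + y

    a, b = Float(1, 2, 3), Float(4, 5, 6)
    r1 = g(a, b)  # First call: traces the function and records its kernels
    r2 = g(b, a)  # Compatible inputs: replays the recorded kernels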

When :py:func:`@dr.freeze <drjit.freeze>` detects *incompatibilities* (e.g., ``x``
having a different type compared to the previous call), it will conservatively
re-trace the body and keep track of another potential input configuration.

Frozen functions support arbitrary :ref:`PyTrees <pytrees>` as function
arguments and return values.
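
For example, a sketch with a dictionary-valued PyTree (names are illustrative):

.. code-block:: python

    @dr.freeze
    def step(params):
        # ``params`` is an arbitrary PyTree, here a dictionary of arrays
        return {k: v + 1 for k, v in params.items()}

    params = dict(scale=Float(1, 2), offset=Float(3, 4))
    out = step(params)  # Traced on the first call
    out = step(params)  # Replayed: types, keys, and layout are unchanged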

The following may trigger re-tracing:

- Changes in the **type** of an argument or :ref:`PyTree <pytrees>` element.
- Changes in the **length** of a container (``list``, ``tuple``, ``dict``).
- Changes of **dictionary keys** or **field names** of dataclasses.
- Changes in the AD status (:py:func:`dr.grad_enabled() <drjit.grad_enabled>`) of a variable.
- Changes of (non-PyTree) **Python objects**, as detected by mismatching ``hash()``.

The following more technical conditions also trigger re-tracing:

- A Dr.Jit variable changes from/to a **scalar** configuration (size ``1``).
- The sets of variables of the same size change. In the example above, this
  would be the case if ``len(x) == len(y)`` in one call, and ``len(x) != len(y)``
  subsequently.
- When Dr.Jit variables reference external memory (e.g. mapped NumPy arrays),
  the memory can be aligned or unaligned. A re-tracing step is needed when this
  status changes.

These all correspond to situations where the generated kernel code may need to
change, and the system conservatively re-traces to ensure correctness.

Frozen functions support arguments with a different variable *width* (see
:py:func:`dr.width() <drjit.width>`) without re-tracing, as long as the sets of
variables of the same width stay consistent.
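
A sketch of this behavior (again assuming ``Float`` from a Dr.Jit backend):

.. code-block:: python

    @dr.freeze
    def double(x):
        return x * 2

    double(dr.arange(Float, 16))    # Recorded at width 16
    double(dr.arange(Float, 1024))  # Replayed at width 1024 without re-tracing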

Some constructs are problematic and should be avoided in frozen functions:

- The function :py:func:`dr.width() <drjit.width>` returns an integer literal
  that may be merged into the generated code. If the frozen function is later
  rerun with differently-sized arguments, the executed kernels will still
  reference the old size. One exception to this rule is a construction like
  ``dr.arange(UInt32, dr.width(a))``, where the result only implicitly depends
  on the width value; see the sketch below.
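
  A sketch contrasting the problematic pattern with the safe one (``UInt32``
  is assumed to come from a Dr.Jit backend module such as ``drjit.cuda``):

  .. code-block:: python

      @dr.freeze
      def bad(x):
          # The integer returned by dr.width(x) may be baked into the kernel
          return x + dr.width(x)

      @dr.freeze
      def good(x):
          # Fine: the result only implicitly depends on the width
          return dr.arange(UInt32, dr.width(x))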

**Advanced features**. The :py:func:`@dr.freeze <drjit.freeze>` decorator takes
several optional parameters that are helpful in certain situations.

- **Warning when re-tracing happens too often**: Incompatible arguments trigger
  re-tracing, which can mask issues where *accidentally* incompatible arguments
  keep :py:func:`@dr.freeze <drjit.freeze>` from producing the expected
  performance benefits.

  In such situations, it can be helpful to warn about and identify the changing
  parameters by name. This warning is enabled by default, with a threshold of
  ``10`` recordings (see the ``warn_after`` parameter).

  .. code-block:: pycon

      >>> @dr.freeze(warn_after=1)
      ... def f(x):
      ...     return x
      ...
      >>> f(Int(1))
      >>> f(Float(1))
      The frozen function has been recorded 2 times, this indicates a problem
      with how the frozen function is being called. For example, calling it
      with changing python values such as an index. For more information about
      which variables changed set the log level to ``LogLevel::Debug``.

- **Limiting memory usage**. Storing kernels for many possible input
  configurations requires device memory, which can become problematic. Set the
  ``limit=`` parameter to enable an LRU (least recently used) cache. This is
  useful when calls to a function are mostly compatible but require occasional
  re-tracing.
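
  For example (a minimal sketch):

  .. code-block:: python

      @dr.freeze(limit=4)  # Keep at most 4 recordings, evicting least recently used
      def f(x):
          return x * 2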

Args:
    limit (Optional[int]): An optional integer specifying the maximum number of
        stored configurations. Once this limit is reached, incompatible calls
        requiring re-tracing will cause the least recently used configuration
        to be dropped.

    warn_after (int): When the number of re-tracing steps exceeds this value,
        Dr.Jit will generate a warning that explains which variables changed
        between calls to the function.

    state_fn (Optional[Callable]): This optional callable can specify
        additional state that identifies the configuration. ``state_fn`` will
        be called with the same arguments as the decorated function. It should
        return a traversable object (e.g., a list or tuple) that is
        conceptually treated as if it were another input of the function.

    backend (Optional[JitBackend]): If no inputs are given when calling the
        frozen function, the backend used has to be specified using this
        argument. It must match the backend used for computation within the
        function.
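
For illustration, a hypothetical use of ``state_fn`` to fold additional Python
state into the input configuration (the ``Renderer`` class is made up):

.. code-block:: python

    class Renderer:
        def __init__(self):
            self.mode = "fast"  # Plain Python state influencing the traced code

    r = Renderer()

    @dr.freeze(state_fn=lambda *args, **kwargs: (r.mode,))
    def render(x):
        return x if r.mode == "fast" else x * 2

Changing ``r.mode`` alters the returned state tuple, triggering a re-trace
instead of silently replaying stale kernels.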
"""


@overload
def freeze(
    f: F,
    *,
    state_fn: Optional[Callable] = None,
    limit: Optional[int] = None,
    warn_after: int = 10,
    backend: Optional[JitBackend] = None,
) -> F: ...


def freeze(
    f: Optional[F] = None,
    *,
    state_fn: Optional[Callable] = None,
    limit: Optional[int] = None,
    warn_after: int = 10,
    backend: Optional[JitBackend] = None,
) -> Union[F, Callable[[F2], F2]]:
    limit = limit if limit is not None else -1
    backend = backend if backend is not None else JitBackend.Invalid

    def decorator(f):
        """
        Internal decorator, returned if ``dr.freeze`` was used with arguments.
        """
        import functools
        import inspect

        def inner(closure, *args, **kwargs):
            """
            This inner function is the one that actually gets frozen. It
            receives any additional state, such as closures or the state
            specified with the ``state_fn`` lambda, and allows for traversal
            of it.
            """
            return f(*args, **kwargs)

        class FrozenFunction:
            def __init__(self, f) -> None:
                closure = inspect.getclosurevars(f)
                self.closure = (closure.nonlocals, closure.globals)
                self.frozen = detail.FrozenFunction(
                    inner,
                    limit,
                    warn_after,
                    backend,
                )

            def __call__(self, *args, **kwargs):
                _state = state_fn(*args, **kwargs) if state_fn is not None else None
                return self.frozen([self.closure, _state], *args, **kwargs)

            @property
            def n_recordings(self):
                """
                Represents the number of times the function was recorded. This
                includes occasions where it was recorded due to a dry-run
                failing. It does not necessarily correspond to the number of
                recordings currently cached; see ``n_cached_recordings`` for
                that.
                """
                return self.frozen.n_recordings

            @property
            def n_cached_recordings(self):
                """
                Represents the number of currently cached recordings of the
                frozen function. If a recording fails in dry-run mode, it does
                not create a new recording, but replaces the recording that it
                attempted to replay. The number of recordings can also be
                limited with the ``limit`` argument.
                """
                return self.frozen.n_cached_recordings

            def clear(self):
                """
                Clears the recordings of the frozen function and resets the
                ``n_recordings`` counter. The reference to the function is
                still kept, and the frozen function can be called again to
                re-trace new recordings.
                """
                return self.frozen.clear()

            def __get__(self, obj, type=None):
                if obj is None:
                    return self
                else:
                    return FrozenMethod(self.frozen, self.closure, obj)

        class FrozenMethod(FrozenFunction):
            """
            A ``FrozenMethod`` curries the object into the ``__call__`` method.

            If the ``freeze`` decorator is applied to a method of some class,
            it has to call the internal frozen function with the ``self``
            argument. To this end, the ``__get__`` method of the frozen
            function returns a ``FrozenMethod``, which holds a reference to
            the object. The ``__call__`` method of the ``FrozenMethod`` then
            supplies the object in addition to the arguments to the internal
            function.
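
            For illustration, a hypothetical sketch (class and method names
            are made up):

            .. code-block:: python

                class Model:
                    @dr.freeze
                    def eval(self, x):
                        return x * 2

                m = Model()
                m.eval(Float(1, 2, 3))  # ``self`` is supplied by the FrozenMethod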
"""
def __init__(self, frozen, closure, obj) -> None:
self.obj = obj
self.frozen = frozen
self.closure = closure

def __call__(self, *args, **kwargs):
_state = state_fn(self.obj, *args, **kwargs) if state_fn is not None else None
return self.frozen([self.closure, _state], self.obj, *args, **kwargs)

return functools.wraps(f)(FrozenFunction(f))

if f is not None:
return decorator(f)
else:
return decorator


del F
del F2

def assert_true(
    cond,

drjit/opt.py (14 additions, 2 deletions)

@@ -126,6 +126,11 @@ class Optimizer(Generic[Extra], MutableMapping[str, dr.ArrayBase]):
    # - an arbitrary sequence of additional optimizer-dependent state values
    state: Dict[str, Tuple[dr.ArrayBase, Optional[LearningRate], Extra]]

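    # Declares the optimizer's traversable fields as a Dr.Jit PyTree
    # ("DRJIT_STRUCT"), presumably so that optimizer objects can be
    # traversed, e.g., when passed to frozen functions.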
    DRJIT_STRUCT = {
        "lr": LearningRate,
        "state": dict,
    }

    def __init__(
        self,
        lr: LearningRate,
Expand Down Expand Up @@ -960,10 +965,15 @@ def _step(
        # Compute the step size scale, which is a product of
        # - EMA debiasing factor
        # - Adaptive/parameter-specific scaling
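        # Note: the debiasing terms beta**t below are computed in float64 and
        # only then cast to float32, presumably to avoid precision loss for
        # large step counts t.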
        Float32 = dr.float32_array_t(dr.leaf_t(grad))
        Float64 = dr.float64_array_t(dr.leaf_t(grad))
        ema_factor = Float32(
            -dr.sqrt(1 - Float64(self.beta_2) ** t) / (1 - Float64(self.beta_1) ** t)
        )
        scale = cache.product(
            dr.leaf_t(grad),  # Desired type
            lr,
-            -dr.sqrt(1 - self.beta_2**t) / (1 - self.beta_1**t),
+            ema_factor,
        )

# Optional: use maximum of second order term
@@ -981,9 +991,11 @@
    def _reset(self, key: str, value: dr.ArrayBase, /) -> None:
        valarr = value.array
        tp = type(valarr)
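        # The step counter is now a Dr.Jit unsigned integer rather than a
        # Python literal, presumably so that it forms part of the traversable
        # optimizer state instead of being baked into kernels as a constant.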
        UInt = dr.uint32_array_t(dr.leaf_t(tp))
        t = UInt(0)
        m_t = dr.opaque(tp, 0, valarr.shape)
        v_t = dr.opaque(tp, 0, valarr.shape)
-        self.state[key] = value, None, (0, m_t, v_t)
+        self.state[key] = value, None, (t, m_t, v_t)

    # Blend between the old and new versions of the optimizer extra state
    def _select(