-
Notifications
You must be signed in to change notification settings - Fork 48
Frozen Functions #336
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
DoeringChristian
wants to merge
3
commits into
master
Choose a base branch
from
frozen-functions
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Frozen Functions #336
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,12 +19,12 @@ | |
import sys as _sys | ||
if _sys.version_info < (3, 11): | ||
try: | ||
from typing_extensions import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable | ||
from typing_extensions import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable, TypeVar | ||
except ImportError: | ||
raise RuntimeError( | ||
"Dr.Jit requires the 'typing_extensions' package on Python <3.11") | ||
else: | ||
from typing import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable | ||
from typing import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable, TypeVar | ||
|
||
from .ast import syntax, hint | ||
from .interop import wrap | ||
|
@@ -2494,6 +2494,256 @@ def binary_search(start, end, pred): | |
|
||
return start | ||
|
||
# Represents the frozen function passed to the decorator without arguments | ||
F = TypeVar("F") | ||
# Represents the frozen function passed to the decorator with arguments | ||
F2 = TypeVar("F2") | ||
|
||
@overload | ||
def freeze( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you try rendering the documentation? (With |
||
f: None = None, | ||
*, | ||
state_fn: Optional[Callable], | ||
limit: Optional[int] = None, | ||
warn_after: int = 10, | ||
backend: Optional[JitBackend] = None, | ||
) -> Callable[[F], F]: | ||
""" | ||
Decorator to "freeze" functions, which improves efficiency by removing | ||
repeated JIT tracing overheads. | ||
|
||
In general, Dr.Jit traces computation and then compiles and launches kernels | ||
containing this trace (see the section on :ref:`evaluation <eval>` for | ||
details). While the compilation step can often be skipped via caching, the | ||
tracing cost can still be significant especially when repeatedly evaluating | ||
complex models, e.g., as part of an optimization loop. | ||
|
||
The :py:func:`@dr.freeze <drjit.freeze>` decorator adresses this problem by | ||
altogether removing the need to trace repeatedly. For example, consider the | ||
following decorated function: | ||
|
||
.. code-block:: python | ||
|
||
@dr.freeze | ||
def f(x, y, z): | ||
return ... # Complicated code involving the arguments | ||
|
||
Dr.Jit will trace the first call to the decorated function ``f()``, while | ||
collecting additional information regarding the nature of the function's inputs | ||
and regarding the CPU/GPU kernel launches representing the body of ``f()``. | ||
|
||
If the function is subsequently called with *compatible* arguments (more on | ||
this below), it will immediately launch the previously made CPU/GPU kernels | ||
without re-tracing, which can substantially improve performance. | ||
|
||
When :py:func:`@dr.freeze <drjit.freeze>` detects *incompatibilities* (e.g., ``x`` | ||
having a different type compared to the previous call), it will conservatively | ||
re-trace the body and keep track of another potential input configuration. | ||
|
||
Frozen functions support arbitrary :ref:`PyTrees <pytrees>` as function | ||
arguments and return values. | ||
|
||
The following may trigger re-tracing: | ||
|
||
- Changes in the **type** of an argument or :ref:`PyTree <pytrees>` element. | ||
- Changes in the **length** of a container (``list``, ``tuple``, ``dict``). | ||
- Changes of **dictionary keys** or **field names** of dataclasses. | ||
- Changes in the AD status (:py:`dr.grad_enabled() <drjit.grad_enabled>`) of a variable. | ||
- Changes of (non-PyTree) **Python objects**, as detected by mismatching ``hash()``. | ||
|
||
The following more technical conditions also trigger re-tracing: | ||
- A Dr.Jit variable changes from/to a **scalar** configuration (size ``1``). | ||
- The sets of variables of the same size change. In the example above, this | ||
would be the case if ``len(x) == len(y)`` in one call, and ``len(x) != len(y)`` | ||
subsequently. | ||
- When Dr.Jit variables reference external memory (e.g. mapped NumPy arrays), the | ||
memory can be aligned or unaligned. A re-tracing step is needed when this | ||
status changes. | ||
|
||
These all correspond to situations where the generated kernel code may need to | ||
change, and the system conservatively re-traces to ensure correctness. | ||
|
||
Frozen functions support arguments with a different variable *width* (see | ||
:py:func:`dr.with() <drjit.width>`) without re-tracing, as long as the sets of | ||
variables of the same width stay consistent. | ||
|
||
Some constructions are problematic and should be avoided in frozen functions. | ||
|
||
- The function :py:func:`dr.width() <drjit.width>` returns an integer literal | ||
that may be merged into the generated code. If the frozen function is later | ||
rerun with differently-sized arguments, the executed kernels will still | ||
reference the old size. One exception to this rule are constructions like | ||
`dr.arange(UInt32, dr.width(a))`, where the result only implicitly depends on | ||
the width value. | ||
|
||
**Advanced features**. The :py:func:`@dr.freeze <drjit.freeze>` decorator takes | ||
several optional parameters that are helpful in certain situations. | ||
|
||
- **Warning when re-tracing happens too often**: Incompatible arguments trigger | ||
re-tracing, which can mask issues where *accidentally* incompatible arguments | ||
keep :py:func:`@dr.freeze <drjit.freeze>` from producing the expected | ||
performance benefits. | ||
|
||
In such situations, it can be helpful to warn and identify changing | ||
parameters by name. This feature is enabled and set to ``10`` by default. | ||
|
||
.. code-block:: pycon | ||
|
||
>>> @dr.freeze(warn_after=1) | ||
>>> def f(x): | ||
... return x | ||
... | ||
>>> f(Int(1)) | ||
>>> f(Float(1)) | ||
The frozen function has been recorded 2 times, this indicates a problem | ||
with how the frozen function is being called. For example, calling it | ||
with changing python values such as an index. For more information about | ||
which variables changed set the log level to ``LogLevel::Debug``. | ||
|
||
- **Limiting memory usage**. Storing kernels for many possible input | ||
configuration requires device memory, which can become problematic. Set the | ||
``limit=`` parameter to enable a LRU cache. This is useful when calls to a | ||
function are mostly compatible but require occasional re-tracing. | ||
|
||
Args: | ||
limit (Optional[int]): An optional integer specifying the maximum number of | ||
stored configurations. Once this limit is reached, incompatible calls | ||
requiring re-tracing will cause the last used configuration to be dropped. | ||
|
||
warn_after (int): When the number of re-tracing steps exceeds this value, | ||
Dr.Jit will generate a warning that explains which variables changed | ||
between calls to the function. | ||
|
||
state_fn (Optional[Callable]): This optional callable can specify additional | ||
state to identifies the configuration. ``state_fn`` will be called with | ||
the same arguments as that of the decorated function. It should return a | ||
traversable object (e.g., a list or tuple) that is conceptually treated | ||
as if it was another input of the function. | ||
|
||
backend (Optional[JitBackend]): If no inputs are given when calling the | ||
frozen function, the backend used has to be specified using this argument. | ||
It must match the backend used for computation within the function. | ||
""" | ||
|
||
|
||
@overload | ||
def freeze( | ||
f: F, | ||
*, | ||
state_fn: Optional[Callable] = None, | ||
limit: Optional[int] = None, | ||
warn_after: int = 10, | ||
backend: Optional[JitBackend] = None, | ||
) -> F: ... | ||
|
||
|
||
def freeze( | ||
f: Optional[F] = None, | ||
*, | ||
state_fn: Optional[Callable] = None, | ||
limit: Optional[int] = None, | ||
warn_after: int = 10, | ||
backend: Optional[JitBackend] = None, | ||
) -> Union[F, Callable[[F2], F2]]: | ||
limit = limit if limit is not None else -1 | ||
backend = backend if backend is not None else JitBackend.Invalid | ||
|
||
def decorator(f): | ||
""" | ||
Internal decorator, returned in ``dr.freeze`` was used with arguments. | ||
""" | ||
import functools | ||
import inspect | ||
|
||
def inner(closure, *args, **kwargs): | ||
""" | ||
This inner function is the one that gets actually frozen. It receives | ||
any additional state such as closures or state specified with the | ||
``state`` lambda, and allows for traversal of it. | ||
""" | ||
return f(*args, **kwargs) | ||
|
||
class FrozenFunction: | ||
def __init__(self, f) -> None: | ||
closure = inspect.getclosurevars(f) | ||
self.closure = (closure.nonlocals, closure.globals) | ||
self.frozen = detail.FrozenFunction( | ||
inner, | ||
limit, | ||
warn_after, | ||
backend, | ||
) | ||
|
||
def __call__(self, *args, **kwargs): | ||
_state = state_fn(*args, **kwargs) if state_fn is not None else None | ||
return self.frozen([self.closure, _state], *args, **kwargs) | ||
|
||
@property | ||
def n_recordings(self): | ||
DoeringChristian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Represents the number of times the function was recorded. This | ||
includes occasions where it was recorded due to a dry-run failing. | ||
It does not necessarily correspond to the number of recordings | ||
currently cached see ``n_cached_recordings`` for that. | ||
""" | ||
return self.frozen.n_recordings | ||
|
||
@property | ||
def n_cached_recordings(self): | ||
DoeringChristian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Represents the number of recordings currently cached of the frozen | ||
function. If a recording fails in dry-run mode, it will not create | ||
a new recording, but replace the recording that was attemted to be | ||
replayed. The number of recordings can also be limited with | ||
the ``max_cache_size`` argument. | ||
""" | ||
return self.frozen.n_cached_recordings | ||
|
||
def clear(self): | ||
""" | ||
Clears the recordings of the frozen function, and resets the | ||
``n_recordings`` counter. The reference to the function is still | ||
kept, and the frozen function can be called again to re-trace | ||
new recordings. | ||
""" | ||
return self.frozen.clear() | ||
|
||
def __get__(self, obj, type=None): | ||
if obj is None: | ||
return self | ||
else: | ||
return FrozenMethod(self.frozen, self.closure, obj) | ||
|
||
class FrozenMethod(FrozenFunction): | ||
""" | ||
A FrozenMethod currying the object into the __call__ method. | ||
|
||
If the ``freeze`` decorator is applied to a method of some class, it has | ||
to call the internal frozen function with the ``self`` argument. To this | ||
end we implement the ``__get__`` method of the frozen function, to | ||
return a ``FrozenMethod``, which holds a reference to the object. | ||
The ``__call__`` method of the ``FrozenMethod`` then supplies the object | ||
in addition to the arguments to the internal function. | ||
""" | ||
def __init__(self, frozen, closure, obj) -> None: | ||
self.obj = obj | ||
self.frozen = frozen | ||
self.closure = closure | ||
|
||
def __call__(self, *args, **kwargs): | ||
_state = state_fn(self.obj, *args, **kwargs) if state_fn is not None else None | ||
return self.frozen([self.closure, _state], self.obj, *args, **kwargs) | ||
|
||
return functools.wraps(f)(FrozenFunction(f)) | ||
|
||
if f is not None: | ||
return decorator(f) | ||
else: | ||
return decorator | ||
|
||
|
||
del F | ||
del F2 | ||
|
||
def assert_true( | ||
cond, | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Submodule drjit-core
updated
10 files
+5 −0 | include/drjit-core/array.h | |
+6 −3 | include/drjit-core/jit.h | |
+22 −0 | src/api.cpp | |
+1 −0 | src/init.cpp | |
+7 −0 | src/internal.h | |
+87 −16 | src/record_ts.cpp | |
+6 −0 | src/record_ts.h | |
+1 −1 | src/util.h | |
+2 −0 | src/var.cpp | |
+22 −0 | tests/record.cpp |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.