Enhancing DifferentiateWith Interface #806

Open
yebai opened this issue May 26, 2025 · 7 comments
Labels
core Related to the core utilities of the package

Comments

@yebai
Contributor

yebai commented May 26, 2025

Current Interface

The existing DifferentiateWith(f, backend) interface in DifferentiationInterface.jl presents a significant limitation: it inherently supports only single-argument functions. This design makes it cumbersome to:

  • Differentiate functions with multiple arguments.
  • Pass additional context or non-differentiable arguments (constants, pre-allocated caches) to the differentiation backend.

Proposed Interface

To address these limitations, we propose a more expressive interface for DifferentiateWith:

Tfunc_sig = Tuple{typeof(f), T_arg1, T_arg2, ..., T_argN}
DifferentiateWith(Tfunc_sig, backend_to_use::AbstractADType)

Here, Tfunc_sig represents the function signature. The first element is the type of the function f (that is, typeof(f)), and the subsequent elements T_arg1, T_arg2, ..., T_argN are the types of the arguments to f.

Argument Type Wrappers:

To provide more context to the backend about how each argument should be treated, we can introduce wrapper types:

  • Default: Arguments are assumed to be "active" (i.e., to be differentiated with respect to).
  • Constant{T}: Indicates that an argument of type T is a constant and should not be differentiated.
  • Cache{T}: Signals that an argument of type T is a pre-allocated cache that the backend can utilise.

Example Usage:

Consider a function f(x, y, z, c) where x and y are active arguments, z is a constant, and c is a cache. Tfunc_sig would be constructed as:

Targtypes = (typeof(x), typeof(y), Constant{typeof(z)}, Cache{typeof(c)})
Tfunc_sig = Tuple{typeof(f), Targtypes...}
# or more explicitly:
# Tfunc_sig = Tuple{typeof(f), typeof(x), typeof(y), Constant{typeof(z)}, Cache{typeof(c)}}

dw = DifferentiateWith(Tfunc_sig, backend)
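
As a rough illustration of what such a signature carries, the sketch below uses plain Julia to recover which argument positions are active from a signature tuple. The marker types SigConstant and SigCache, the helper is_active, and the function toy_f are hypothetical stand-ins invented for this example; they are not part of DifferentiationInterface.jl, and the proposed two-argument DifferentiateWith(Tfunc_sig, backend) constructor is not called here.

```julia
# Hypothetical type-level markers mirroring the proposed Constant{T} / Cache{T}.
struct SigConstant{T} end
struct SigCache{T} end

# Toy function; only its type is used below.
toy_f(x, y, z, c) = x * y + z

x, y, z, c = 1.0, 2.0, 3, zeros(2)
Tfunc_sig = Tuple{typeof(toy_f), typeof(x), typeof(y), SigConstant{typeof(z)}, SigCache{typeof(c)}}

# Unwrapped argument types are treated as active by default.
is_active(::Type) = true
is_active(::Type{<:SigConstant}) = false
is_active(::Type{<:SigCache}) = false

# Drop the function type, then collect the positions of active arguments.
arg_types = fieldtypes(Tfunc_sig)[2:end]
active_positions = [i for (i, T) in enumerate(arg_types) if is_active(T)]
```

Note that the signature records only typeof(toy_f), not the function object itself, which is the point debated later in this thread.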

Internal Handling:

With this richer Tfunc_sig, DifferentiateWith can internally manage functions with multiple arguments. For backends that fundamentally operate on single-argument functions (e.g., by packing arguments into a tuple), DifferentiateWith can perform this packing/unpacking automatically before invoking the backend's pushforward or pullback implementations. This keeps the backend APIs simpler while providing a user-friendly multi-argument interface.
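
The packing/unpacking idea can be sketched in plain Julia as follows. This is only a minimal illustration under assumptions: ArgConstant, ArgCache, unwrap, as_single_argument, and affine are hypothetical names made up for this example, and no real backend is invoked.

```julia
# Hypothetical value-level wrappers (illustration only, not a real API).
struct ArgConstant{T}
    data::T
end
struct ArgCache{T}
    data::T
end

# Plain (active) arguments pass through; wrapped ones are unwrapped.
unwrap(x) = x
unwrap(c::ArgConstant) = c.data
unwrap(c::ArgCache) = c.data

# Turn a multi-argument function into one that takes a single packed tuple.
as_single_argument(f) = packed -> f(map(unwrap, packed)...)

affine(x, y, z) = x * y + z
g = as_single_argument(affine)

packed = (2.0, 3.0, ArgConstant(1.0))
result = g(packed)  # same value as affine(2.0, 3.0, 1.0)
```

A backend that only understands single-argument functions would then see g and packed, while the user keeps writing affine(x, y, z).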

yebai changed the title from "Relax DifferentiateWith" to "Relax some DifferentiateWith constraints" on May 26, 2025
yebai changed the title from "Relax some DifferentiateWith constraints" to "Enhancing DifferentiateWith Interface" on Jun 1, 2025
gdalle added the core label (Related to the core utilities of the package) on Jun 2, 2025
@gdalle
Member

gdalle commented Jun 2, 2025

Thanks for the write-up! Just to clarify, multiple active arguments are not supported by DI, nor will they be supported in the foreseeable future (#683). I agree with the need for contexts though, which was already suggested in #675.

As for the API you suggest, passing the signature isn't enough because f itself might contain data that we need to call and differentiate it. To me, a simpler interface like:

DifferentiateWith(f, backend, (Constant, Constant, Cache))

might suffice, where the third positional argument is the ordered tuple of wrappers that will be applied to each context. The default would be the empty tuple (), which corresponds to the current behavior. What do you think?
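
To make "the ordered tuple of wrappers applied to each context" concrete, here is a small plain-Julia sketch. ToyConstant, ToyCache, and wrap_contexts are hypothetical names used only for illustration; the three-argument DifferentiateWith constructor discussed above is a suggestion in this thread and is not called here.

```julia
# Hypothetical stand-ins for context wrappers (illustration only).
struct ToyConstant{T}
    data::T
end
struct ToyCache{T}
    data::T
end

# Apply the i-th wrapper to the i-th context, in order.
function wrap_contexts(wrappers::Tuple, contexts::Tuple)
    length(wrappers) == length(contexts) ||
        throw(ArgumentError("expected one wrapper per context"))
    return map((W, c) -> W(c), wrappers, contexts)
end

wrapped = wrap_contexts((ToyConstant, ToyConstant, ToyCache), (1.0, "label", zeros(3)))

# The empty tuple of wrappers corresponds to having no contexts at all.
no_contexts = wrap_contexts((), ())
```

Because the function object f is passed alongside the wrappers, any data it closes over remains available, unlike with a type-only signature.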

@yebai
Contributor Author

yebai commented Jun 3, 2025

> passing the signature isn't enough because f itself might contain data

It is a design choice. Not all backends support callable objects closing over active data. Also, for implementation simplicity and thorough testing purposes, it might be better to pass the type signature instead. The tradeoff is that the type signature approach would forbid callable objects with active data when using DifferentiateWith. Happy to follow your call here.

By the way, have you encountered any examples where callable objects with active data could be critical?

@gdalle
Member

gdalle commented Jun 3, 2025

It's not just active data, it's any data. If we stick to the type signature, we can't differentiate any closure at all.

@yebai
Contributor Author

yebai commented Jun 3, 2025

> It's not just active data, it's any data. If we stick to the type signature, we can't differentiate any closure at all

Do we need to differentiate closures? Users can always do _f(f) = f(), then DifferentiateWith(_f, ...), right?

@gdalle
Member

gdalle commented Jun 3, 2025

Yes, we do. I don't quite understand your suggestion, but even there you're passing a function object, not a function type.

@yebai
Contributor Author

yebai commented Jun 3, 2025

It should be _f(f) = f(), then DifferentiateWith(typeof(_f), ...).

@gdalle
Member

gdalle commented Jun 3, 2025

Regardless, you can't compute derivatives of closures if you only give type information. Consider the following closure:

julia> f = let a = 2
           x -> a * x
       end
#3 (generic function with 1 method)

julia> typeof(f)
var"#3#4"{Int64}

julia> dump(f)
#3 (function of type var"#3#4"{Int64})
  a: Int64 2

Its type does not tell us that it performs multiplication by 2; we need the actual object for that.
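
The same point can be checked with two closures built from one definition. In this sketch (make_scaler is a hypothetical helper name), both closures share a single type yet compute different functions, so the type alone cannot determine the derivative.

```julia
# Each call returns a closure capturing a different value of `a`.
make_scaler(a) = x -> a * x

double = make_scaler(2)
triple = make_scaler(3)

# Same anonymous-function type, parameterized only by the type of `a`.
same_type = typeof(double) === typeof(triple)

# Different behavior, hence different derivatives (2 versus 3).
outputs = (double(5), triple(5))

# The captured value lives in the object, not in the type.
captured = (double.a, triple.a)
```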


2 participants