Subtypes of ADGradientWrapper are defined in package extension, thus unable to be dispatched on #32

Closed
sunxd3 opened this issue Jul 4, 2024 · 6 comments

Comments

@sunxd3

sunxd3 commented Jul 4, 2024

There are use cases where users may want to dispatch on a particular subtype of ADGradientWrapper to customize behavior depending on the AD backend.

One example of this is TuringLang/DynamicPPL.jl#626, where we want to be able to dispatch on ReverseDiffLogDensity, because we may need to recompute the cached tape. Because ReverseDiffLogDensity is defined in a package extension, its name is currently not visible to other packages.

Would it make sense to move all (or some) of the definitions of the ADGradientWrapper subtypes into the main package to expose the wrapper names?

@yebai

yebai commented Jul 5, 2024

@sunxd3 It is probably okay to open a PR to discuss this.

cc @devmotion

@devmotion
Collaborator

Dispatching on these types in other packages seems quite brittle to me, regardless of whether they are available or not. These are considered to be internal types whose fields etc. might change at any point.

In the context of the DynamicPPL/Turing PR, the safest option seems to be to reconstruct the gradient wrapper with ADgradient and a new LogDensityFunction built from the updated model. The AD type (plus options) should already be available to the sampler (Gibbs in that particular example), since it is also used to construct the additional AD-capable log density function? This last point also suggests that you should never need to dispatch on types such as ReverseDiffLogDensity, since the AD struct (AutoReverseDiff etc.) contains all the information?
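
As a rough illustration of the reconstruction approach described above, the sketch below keeps the backend specification around and calls ADgradient again when the underlying log density changes, rather than inspecting the wrapper type. `ShiftedNormal` is a hypothetical toy problem invented for this example, and ForwardDiff is used only to keep the example small; it is an assumption that the same pattern carries over unchanged to the ReverseDiff case discussed in the PR.

```julia
using LogDensityProblems, LogDensityProblemsAD
using ForwardDiff  # loading the backend makes the corresponding ADgradient method available

# Hypothetical toy problem (not part of any package): unnormalized normal centred at μ.
struct ShiftedNormal
    μ::Vector{Float64}
end

LogDensityProblems.logdensity(p::ShiftedNormal, x) = -sum(abs2, x .- p.μ) / 2
LogDensityProblems.dimension(p::ShiftedNormal) = length(p.μ)
LogDensityProblems.capabilities(::Type{<:ShiftedNormal}) =
    LogDensityProblems.LogDensityOrder{0}()

# The caller keeps the backend specification around ...
backend = Val(:ForwardDiff)
∇ℓ = ADgradient(backend, ShiftedNormal([0.0, 0.0]))

# ... so when the model changes, the wrapper is rebuilt from scratch
# instead of dispatching on the concrete wrapper type.
∇ℓ_new = ADgradient(backend, ShiftedNormal([1.0, 2.0]))

value, grad = LogDensityProblems.logdensity_and_gradient(∇ℓ_new, [0.0, 0.0])
# value ≈ -2.5 and grad ≈ [1.0, 2.0]
```

Rebuilding the wrapper discards any cached state the old wrapper held, which is the behavior wanted when a cached tape would otherwise become stale.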

@tpapp
Owner

tpapp commented Jul 7, 2024

I did not look into the PR you linked, so I am not sure about the use case, but I agree with @devmotion: it would be very brittle to dispatch on internal types of this package, regardless of whether they are defined in an extension or not. The only symbol we export is ADgradient.

Nevertheless, if you need this functionality, we could add an API for constructing a gradient from an existing one,

  1. replacing the log density function
  2. using another AD implementation

Whether we should use Base.similar for that or not is a matter of bikeshedding.

Would that help?
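
One possible shape for such an API, as a minimal sketch in plain Julia: every name here (`ToyGradientWrapper`, `backend_of`, `rebuild`) is hypothetical and does not exist in LogDensityProblemsAD. The wrapper records which backend created it, and a single function returns a new wrapper with either the log density or the backend replaced.

```julia
# Hypothetical stand-in for a gradient wrapper; not the package's actual type.
struct ToyGradientWrapper{B,L}
    backend::B  # marker for the AD backend, here just a Symbol
    ℓ::L        # the wrapped log density, here any callable
end

# Hypothetical accessor for the backend that created the wrapper.
backend_of(w::ToyGradientWrapper) = w.backend

# Hypothetical constructor-from-existing: anything not overridden is reused.
function rebuild(w::ToyGradientWrapper; ℓ = w.ℓ, backend = w.backend)
    return ToyGradientWrapper(backend, ℓ)
end

w1 = ToyGradientWrapper(:ReverseDiff, x -> -sum(abs2, x) / 2)

# 1. replacing the log density function, keeping the backend
w2 = rebuild(w1; ℓ = x -> -sum(abs2, x .- 1) / 2)

# 2. using another AD implementation, keeping the log density
w3 = rebuild(w1; backend = :ForwardDiff)
```

Callers written against this kind of interface only need the generic function, so they never have to name the backend-specific wrapper types.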

@torfjelde
Contributor

> Nevertheless, if you need this functionality, we could add API for constructing a gradient from an existing one,

That would indeed help a lot:)

And yeah, I agree with all the reasons for not dispatching on the abstract gradient wrapper. An alternative would be to store the ADType in the gradient wrapper and provide a way to extract it. But I think the methods you suggest should do the trick!

The motivation is to use this to implement Gibbs samplers that work without explicit knowledge of the other samplers used within the Gibbs sampler.

@yebai

yebai commented Jul 10, 2024

@sunxd3, can you help create a PR for the proposed API?

@sunxd3
Author

sunxd3 commented Jul 19, 2024

As discussed in #33, a nicer resolution would be something like getAD(), which lets us trigger recreation of the wrapper while still having the option to pass the kwargs. Closing this for now; I will start a PR for the aforementioned interface when demand arises.

@sunxd3 sunxd3 closed this as completed Jul 19, 2024