Subtypes of `ADGradientWrapper` are defined in a package extension and thus cannot be dispatched on #32
Comments
@sunxd3 It is probably okay to open a PR to discuss this. cc @devmotion
Dispatching on these types in other packages seems quite brittle to me, regardless of whether they are available or not. They are considered internal types whose fields, among other things, might change at any point. In the context of the DynamicPPL/Turing PR, the safest option seems to be to simply reconstruct the gradient wrapper with …
I did not look into the PR you linked, so I am not sure about the use case, but I agree with @devmotion: it would be very brittle to dispatch on internal types of this package, regardless of whether they are defined in an extension or not. The only symbol we export is …. Nevertheless, if you need this functionality, we could add an API for constructing a gradient from an existing one, … Whether we should use … Would that help?
That would indeed help a lot :) And yes, I agree with all the reasons for not dispatching on `ADGradientWrapper` subtypes. An alternative would be to store the ADType in the gradient wrapper and provide a way to extract it. But the methods you suggest should do the trick! The motivation is to use this to implement Gibbs samplers that work without explicit knowledge of the other samplers they are combined with.
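A rough, self-contained Julia sketch of the alternative mentioned above: store the AD backend in the wrapper and expose an accessor, so downstream code dispatches on a public backend type instead of an internal wrapper type. Every name here (`AbstractADBackend`, `AutoReverseDiffLike`, `AbstractGradientWrapper`, `ReverseDiffLikeWrapper`, `adtype`, `needs_tape_recompile`) is hypothetical and not part of LogDensityProblemsAD.jl; in practice the backend types would presumably come from ADTypes.jl.

```julia
# Hypothetical sketch: none of these names exist in LogDensityProblemsAD.jl.

# Public marker types describing AD backends (stand-ins for ADTypes.jl-style types).
abstract type AbstractADBackend end
struct AutoReverseDiffLike <: AbstractADBackend
    compile::Bool   # whether a compiled (cached) tape is used
end
struct AutoForwardDiffLike <: AbstractADBackend end

# Internal wrapper hierarchy: concrete subtypes could live in package extensions.
abstract type AbstractGradientWrapper end

struct ReverseDiffLikeWrapper{L} <: AbstractGradientWrapper
    ℓ::L                          # wrapped log density
    backend::AutoReverseDiffLike  # backend stored alongside it
end

struct ForwardDiffLikeWrapper{L} <: AbstractGradientWrapper
    ℓ::L
    backend::AutoForwardDiffLike
end

# Public accessor: the only thing downstream code needs to know about a wrapper.
adtype(∇ℓ::ReverseDiffLikeWrapper) = ∇ℓ.backend
adtype(∇ℓ::ForwardDiffLikeWrapper) = ∇ℓ.backend

# Downstream code (e.g. a Gibbs sampler) dispatches on the backend, not the wrapper.
needs_tape_recompile(backend::AutoReverseDiffLike) = backend.compile
needs_tape_recompile(::AbstractADBackend) = false
needs_tape_recompile(∇ℓ::AbstractGradientWrapper) = needs_tape_recompile(adtype(∇ℓ))

# Usage
logdensity(x) = -sum(abs2, x) / 2
rd = ReverseDiffLikeWrapper(logdensity, AutoReverseDiffLike(true))
fd = ForwardDiffLikeWrapper(logdensity, AutoForwardDiffLike())
needs_tape_recompile(rd)  # true
needs_tape_recompile(fd)  # false
```

Because `adtype` is the only name other packages touch in this sketch, the wrapper structs could be renamed, restructured, or moved into extensions without breaking dispatch downstream.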
@sunxd3, can you help create a PR for the proposed API? |
As discussed in #33, a nicer resolution would be something like …
There are use cases where users may want to dispatch on a particular subtype of `ADGradientWrapper` to customize behavior for different AD backends. One example is TuringLang/DynamicPPL.jl#626, where we want to dispatch on `ReverseDiffLogDensity` because we may need to recompute the cached tape. However, because `ReverseDiffLogDensity` is defined in a package extension, it is not visible to other packages.

Would it make sense to move all (or some) of the `ADGradientWrapper` subtype definitions into the main package so that the wrapper names are exposed?