Consider unifying approach to PRNG state #980
Comments
Attempting to map out a decision tree:
Thanks for mapping these out. By default I lean towards (2) because it is familiar, imposes looser requirements (maybe one class uses cuRAND, a different class only uses numpy, ...), and seems to make certain kinds of errors harder (forgetting to split keys, etc.). But (1) does seem cleaner. An additional practice that may be compatible with all of the above options is to implement a random function
That's a good point that forgetting to split keys would be a class of error unique to (1b) (added to "disadvantages" above).
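To make the "forgetting to split keys" failure mode concrete, here is a minimal sketch using JAX's explicit-key API. The seed, shapes, and variable names are illustrative assumptions and are not taken from this project's code.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)  # arbitrary seed, chosen only for illustration

# Mistake: the same key is passed to two draws that are meant to be independent.
# No error is raised, but both calls return identical samples.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))
same_when_reused = bool(jnp.array_equal(a, b))

# Intended usage: split the key first, then consume each subkey exactly once.
key_a, key_b = jax.random.split(key)
c = jax.random.normal(key_a, (3,))
d = jax.random.normal(key_b, (3,))
same_when_split = bool(jnp.array_equal(c, d))
```

The reused-key case fails silently, which is what makes this class of error easy to miss in review.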
This seems similar in spirit to option (1) to me, but does have the benefit of being RNG-agnostic. It seems like it does introduce some additional room for error, though. E.g. the caller must be careful to ensure that the input randomness has the expected distribution, and because generation is decoupled from transformation, it might be harder to keep in sync.
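A rough sketch of the pattern under discussion, where the random function is written as a deterministic transform of caller-supplied noise. The function name `gaussian_from_noise` and the use of the standard-library `random` module are hypothetical choices for illustration; any RNG could supply the noise.

```python
import random


def gaussian_from_noise(mean, std, z):
    """Deterministic transform of input noise (hypothetical helper).

    Assumes z was drawn from a standard normal; nothing here checks that.
    """
    return mean + std * z


rng = random.Random(2023)  # any RNG would do; the transform never sees it

# Correct pairing: standard-normal noise goes into the Gaussian transform.
z = rng.gauss(0.0, 1.0)
x = gaussian_from_noise(1.0, 0.5, z)

# Mismatched pairing: uniform noise on [0, 1) is passed instead.
# The call succeeds, but y is not a sample from N(1.0, 0.5**2).
u = rng.random()
y = gaussian_from_noise(1.0, 0.5, u)
```

The second call shows the concern raised above: the transform cannot detect that its input has the wrong distribution, so keeping generation and transformation in sync is left to the caller.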
That's true, and this is also a realistic concern in the context of reweighting, where we might have a deterministic function
So it doesn't get lost, some further observations from @mcwitt in #1128 (comment)
Thanks to @mcwitt for thoughtful comments: migrating from #978 (comment). Would be good to discuss and adopt project-wide conventions, if possible.
Some approaches currently used:
Some possible trade-offs:
See also: https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html