-
-
Notifications
You must be signed in to change notification settings - Fork 53
[mccall_fitted_vfi] JAX conversion #567
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Many thanks @longye-tian, it looks really nice!
I just pushed some minor fix of typos.
Please see some very minor suggestions below (most are just cosmetic)!
Co-authored-by: Humphrey Yang <[email protected]>
Co-authored-by: Humphrey Yang <[email protected]>
Co-authored-by: Humphrey Yang <[email protected]>
Co-authored-by: Humphrey Yang <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR modernizes the Job Search III (Fitted Value Function Iteration) lecture by converting from NumPy/Numba to JAX, implementing functional programming patterns and fixing style guide violations. The conversion improves compatibility with modern QuantEcon standards while maintaining the lecture's educational content.
Key changes:
- Converts NumPy/Numba code to JAX with functional programming patterns
- Fixes section heading capitalization per QuantEcon style guide
- Updates random number generation to use JAX's PRNG system
Co-authored-by: Copilot <[email protected]>
Thanks @longye-tian for an excellent PR, thanks @HumphreyYang for comments. I just had a few questions.
![]()
![]()
|
@HumphreyYang I see 95f5ef6 the result is 10s so 2s faster than |
Hi @mmcky, This is an interesting behaviour. I am curious could it be because of the I will run some more benchmark. |
So I ran the
so it does appear to be a GitHub Actions thing. |
Hi @mmcky, Please find the gist result from my end: https://gist.github.com/HumphreyYang/a488df0e8ef01b8d91af6e6309d5e45f It should run around 2x faster than the numba version in both CPU and GPU. There are while loops in the code so I expect GPU wouldn't help much. |
Great! Our number matches! It looks like 8vCPUs are not helping : ) |
@longye-tian I have updated the timing comment above (with strikethrough) based on our investigations and opened #568 |
Thanks @mmcky for the detailed review: JAX and NumPy use different pseudorandom number generators (PRNGs), so even with the same seed, they produce different random numbers. Below is an illustration: |
Thanks @HumphreyYang so once we change to |
Hi @mmcky, yes, the new PRNGs should output the same solution everytime and the interpretation is not changed. Everything is in the same magnitude as the old code so the interpretations are the same. |
thanks @longye-tian for an excellent PR. thanks @HumphreyYang for your comments and reviews. @jstac this will serve as another reference in the |
@jstac I tagged your review for this one so you have 👀 on what the team has put together. I'll merge this on Monday. |
Hi @mmcky @HumphreyYang
This PR updates the Job Search III (Fitted Value Function Iteration) lecture to use JAX instead of NumPy/Numba, following modern QuantEcon standards and fixing style guide violations.
Reference to Issue #502
Changes
Style Fixes (per QuantEcon style guide)
## The Algorithm
→## The algorithm
### Value Function Iteration
→### Value function iteration
### Fitted Value Function Iteration
→### Fitted value function iteration
JAX Conversion
@jitclass
toNamedTuple
with factory function patternjax.lax.while_loop
jax.vmap
Text Consistency
numpy.interp
tojnp.interp
np.inf
→jnp.inf
,np.linspace
→jnp.linspace