Skip to content

Conversation

longye-tian
Copy link
Contributor

@longye-tian longye-tian commented Aug 20, 2025

Hi @mmcky @HumphreyYang
This PR updates the Job Search III (Fitted Value Function Iteration) lecture to use JAX instead of NumPy/Numba, following modern QuantEcon standards and fixing style guide violations.

Reference to Issue #502

Changes

Style Fixes (per QuantEcon style guide)

  • Fixed section heading capitalization (Title Rule):
    • ## The Algorithm## The algorithm
    • ### Value Function Iteration### Value function iteration
    • ### Fitted Value Function Iteration### Fitted value function iteration

JAX Conversion

  • Replaced NumPy/Numba imports with JAX equivalents
  • Converted @jitclass to NamedTuple with factory function pattern
  • Updated random number generation to use JAX's PRNG keys
  • Replaced imperative loops with jax.lax.while_loop
  • Implemented functional programming patterns in exercises using jax.vmap
  • Updated all array operations from NumPy to JAX

Text Consistency

  • Updated documentation references from numpy.interp to jnp.interp
  • Fixed code references: np.infjnp.inf, np.linspacejnp.linspace

Copy link

github-actions bot commented Aug 20, 2025

@github-actions github-actions bot temporarily deployed to pull request August 20, 2025 10:20 Inactive
@github-actions github-actions bot temporarily deployed to pull request August 20, 2025 10:20 Inactive
Copy link
Member

@HumphreyYang HumphreyYang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Many thanks @longye-tian, it looks really nice!

I just pushed some minor fix of typos.

Please see some very minor suggestions below (most are just cosmetic)!

@github-actions github-actions bot temporarily deployed to pull request August 20, 2025 12:27 Inactive
@github-actions github-actions bot temporarily deployed to pull request August 20, 2025 12:29 Inactive
@mmcky mmcky added the lecture label Aug 21, 2025
@github-actions github-actions bot temporarily deployed to pull request August 21, 2025 01:29 Inactive
@github-actions github-actions bot temporarily deployed to pull request August 21, 2025 01:30 Inactive
@mmcky mmcky requested a review from Copilot August 21, 2025 05:54
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR modernizes the Job Search III (Fitted Value Function Iteration) lecture by converting from NumPy/Numba to JAX, implementing functional programming patterns and fixing style guide violations. The conversion improves compatibility with modern QuantEcon standards while maintaining the lecture's educational content.

Key changes:

  • Converts NumPy/Numba code to JAX with functional programming patterns
  • Fixes section heading capitalization per QuantEcon style guide
  • Updates random number generation to use JAX's PRNG system

@github-actions github-actions bot temporarily deployed to pull request August 21, 2025 06:25 Inactive
@mmcky
Copy link
Contributor

mmcky commented Aug 21, 2025

Thanks @longye-tian for an excellent PR, thanks @HumphreyYang for comments. I just had a few questions.

  1. the figure in exercise 2 is slightly different. This might be OK -- but just thought I would check if this is due to a random seed, or if it has any implications for the exercise (LHS = this PR, RHS = Live site)
Screenshot 2025-08-21 at 4 50 25 pm

2. The timing is slower on JAX version (GitHub actions build - gpu) than the numba version. (LHS = this PR, RHS = Live site) (Status: Local testing shows jax should run approximately 2 x faster than numba version, so this timing is an issue with GA environment). #567 (comment)

Screenshot 2025-08-21 at 4 53 16 pm
  • review collab comments to either agree or resolve.

@github-actions github-actions bot temporarily deployed to pull request August 21, 2025 07:14 Inactive
@github-actions github-actions bot temporarily deployed to pull request August 21, 2025 07:15 Inactive
@mmcky
Copy link
Contributor

mmcky commented Aug 22, 2025

@HumphreyYang I see 95f5ef6 the result is 10s so 2s faster than gpu and is the same as the numba implementation.

@HumphreyYang
Copy link
Member

@HumphreyYang I see 95f5ef6 the result is 10s so 2s faster than gpu and is the same as the numba implementation.

Hi @mmcky,

This is an interesting behaviour. I am curious could it be because of the jb build issue.

I will run some more benchmark.

@mmcky
Copy link
Contributor

mmcky commented Aug 22, 2025

@HumphreyYang I see 95f5ef6 the result is 10s so 2s faster than gpu and is the same as the numba implementation.

Hi @mmcky,

This is an interesting behaviour. I am curious could it be because of the jb build issue.

I will run some more benchmark.

So I ran the main branch and this branch locally

state time
main (numba) 6s
job_separation (jax) 3s

so it does appear to be a GitHub Actions thing.
Surprising as we are running this on an g4dn.2xlarge which has 8 vCPUs with 64Gb RAM.

@HumphreyYang
Copy link
Member

Hi @mmcky,

Please find the gist result from my end:

https://gist.github.com/HumphreyYang/a488df0e8ef01b8d91af6e6309d5e45f

It should run around 2x faster than the numba version in both CPU and GPU.

There are while loops in the code so I expect GPU wouldn't help much.

@HumphreyYang
Copy link
Member

@HumphreyYang I see 95f5ef6 the result is 10s so 2s faster than gpu and is the same as the numba implementation.

Hi @mmcky,
This is an interesting behaviour. I am curious could it be because of the jb build issue.
I will run some more benchmark.

So I ran the main branch and this branch locally

state time
main (numba) 6s
job_separation (jax) 3s
so it does appear to be a GitHub Actions thing. Surprising as we are running this on an g4dn.2xlarge which has 8 vCPUs with 64Gb RAM.

Great! Our number matches! It looks like 8vCPUs are not helping : )

@mmcky
Copy link
Contributor

mmcky commented Aug 22, 2025

@longye-tian I have updated the timing comment above (with strikethrough) based on our investigations and opened #568

@HumphreyYang
Copy link
Member

HumphreyYang commented Aug 22, 2025

  1. the figure in exercise 2 is slightly different. This might be OK -- but just thought I would check if this is due to a random seed, or if it has any implications for the exercise (LHS = this PR, RHS = Live site)

Thanks @mmcky for the detailed review:

JAX and NumPy use different pseudorandom number generators (PRNGs), so even with the same seed, they produce different random numbers. Below is an illustration:
https://gist.github.com/HumphreyYang/f8aa7b322c48723bac568290de8a08fc

@mmcky
Copy link
Contributor

mmcky commented Aug 22, 2025

  1. the figure in exercise 2 is slightly different. This might be OK -- but just thought I would check if this is due to a random seed, or if it has any implications for the exercise (LHS = this PR, RHS = Live site)

Thanks @mmcky for the detailed review:

JAX and NumPy use different pseudorandom number generators (PRNGs), so even with the same seed, they produce different random numbers. Below is an illustration: https://gist.github.com/HumphreyYang/f8aa7b322c48723bac568290de8a08fc

Thanks @HumphreyYang so once we change to jax that graph will become static again (with the new shape)? Also just to confirm, the new shape doesn't change the interpretation in any way re: context of the solution?

@HumphreyYang
Copy link
Member

  1. the figure in exercise 2 is slightly different. This might be OK -- but just thought I would check if this is due to a random seed, or if it has any implications for the exercise (LHS = this PR, RHS = Live site)

Thanks @mmcky for the detailed review:
JAX and NumPy use different pseudorandom number generators (PRNGs), so even with the same seed, they produce different random numbers. Below is an illustration: https://gist.github.com/HumphreyYang/f8aa7b322c48723bac568290de8a08fc

Thanks @HumphreyYang so once we change to jax that graph will become static again (with the new shape)? Also just to confirm, the new shape doesn't change the interpretation in any way re: context of the solution?

Hi @mmcky, yes, the new PRNGs should output the same solution everytime and the interpretation is not changed. Everything is in the same magnitude as the old code so the interpretations are the same.

@mmcky mmcky marked this pull request as ready for review August 22, 2025 04:19
@mmcky mmcky added the ready label Aug 22, 2025
@github-actions github-actions bot temporarily deployed to pull request August 22, 2025 04:50 Inactive
@github-actions github-actions bot temporarily deployed to pull request August 22, 2025 04:50 Inactive
@mmcky
Copy link
Contributor

mmcky commented Aug 22, 2025

thanks @longye-tian for an excellent PR.

thanks @HumphreyYang for your comments and reviews.

@jstac this will serve as another reference in the numba -> jax conversion work.

@mmcky mmcky requested a review from jstac August 22, 2025 06:24
@mmcky
Copy link
Contributor

mmcky commented Aug 22, 2025

@jstac I tagged your review for this one so you have 👀 on what the team has put together.

I'll merge this on Monday.

@mmcky mmcky merged commit c6ae618 into main Aug 25, 2025
7 checks passed
@mmcky mmcky deleted the job_separation branch August 25, 2025 02:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants