@sanepunk
Contributor

@sanepunk sanepunk commented Oct 19, 2025

Description

This PR adds a complete implementation of the REINFORCE policy gradient algorithm using Flax NNX for the CartPole-v1 environment.

Motivation

  • Demonstrates how to implement reinforcement learning algorithms with Flax NNX
  • Provides a canonical example of policy gradient methods
  • Serves as an entry point for users interested in RL with Flax
  • Complements existing supervised learning examples in the repository

Implementation Details

Algorithm: REINFORCE (Williams, 1992) - Monte Carlo policy gradient method
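
For readers new to REINFORCE, the Monte Carlo policy gradient objective can be sketched roughly as follows. This is a minimal illustration and not the notebook's code: `reinforce_loss` and its argument names are hypothetical, and averaging over time steps (rather than summing) is an assumption.

```python
import jax
import jax.numpy as jnp


def reinforce_loss(logits, actions, returns):
    """Negative return-weighted log-likelihood of the actions taken.

    logits:  (T, n_actions) unnormalized action scores for one episode
    actions: (T,) integer actions that were actually taken
    returns: (T,) discounted returns G_t, treated as constants
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick log pi(a_t | s_t) for the action taken at each step.
    taken = jnp.take_along_axis(log_probs, actions[:, None], axis=-1)[:, 0]
    # Minimizing this loss performs gradient ascent on expected return.
    return -jnp.mean(taken * jax.lax.stop_gradient(returns))


# Toy episode: uniform policy over 2 actions, 3 steps, unit returns.
logits = jnp.zeros((3, 2))
actions = jnp.array([0, 1, 0])
returns = jnp.array([1.0, 1.0, 1.0])

loss = reinforce_loss(logits, actions, returns)
grads = jax.grad(reinforce_loss)(logits, actions, returns)
```

With positive returns, the gradient with respect to the logit of each taken action is negative, so a gradient descent step raises that action's probability.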

Architecture:

  • 3-layer MLP policy network (4→128→128→2)
  • Leaky ReLU activations
  • Xavier normal weight initialization
  • Categorical action distribution

Training:

  • Adam optimizer with exponential learning rate decay
  • Gradient clipping (global norm = 1.0)
  • Return normalization for stable training
  • Discount factor γ = 0.99

Environment: CartPole-v1 via Gymnax

What's Included

  • examples/reinforce/simple_reinforce.ipynb - Complete Jupyter notebook with:
    • Environment setup
    • Policy network definition using NNX modules
    • Training loop with progress tracking
    • Visualization of training curves and agent behavior
  • examples/reinforce/README.md - Comprehensive documentation
  • examples/reinforce/requirements.txt - All dependencies
  • examples/reinforce/training_rewards.png - Training curve visualization
  • examples/reinforce/anim.gif - Trained agent animation

Performance

  • Training time: ~12 minutes on CPU
  • Episodes to solve: ~300-400 episodes
  • Final average reward: 490+ (solving threshold: 475)
  • Throughput: 1.38 episodes/second

@vfdev-5
Collaborator

vfdev-5 commented Oct 20, 2025

@sanepunk thanks for the PR. I'll ask the team whether we would like to have a reinforcement learning example in the Flax repository (considering the maintenance of this code, etc.) and will let you know the decision here.

@sanepunk
Contributor Author

Thanks, @vfdev-5! I noticed there wasn’t an NNX-based RL example in the repo, so I thought it might be helpful to contribute one. Looking forward to hearing the team’s decision.

@vfdev-5
Collaborator

vfdev-5 commented Oct 20, 2025

@sanepunk unfortunately, it was decided not to add a new RL example to the Flax examples for now. There are several reasons:

  • we should refresh and update the existing examples so that they use the latest Flax API (e.g. >=0.12.0)
  • the future maintenance cost of your code: keeping versions up to date and fixing all backward-compatibility (BC) breaking changes. For example, the earlier change from gym to gymnasium and its BC-breaking changes were a bit painful to dig into.
  • a debatable dependency: gymnax had its last commit 5 months ago, and it is unclear whether we can rely on it.

Thanks again for this RL example with Flax! I understand that this decision can be disappointing, as you have already done a decent amount of work. For future contributions, a good practice is to first open an issue in the repository describing what you would like to add and get feedback from the maintainers.

@sanepunk
Contributor Author

@vfdev-5 Thank you so much for taking the time to review and explain this! I really appreciate the feedback and will definitely open an issue first next time.

Would it be okay if I open an issue to discuss adding NNX implementations to the existing nn module examples, since it could help people understand and learn the NNX module better?

@vfdev-5
Collaborator

vfdev-5 commented Oct 20, 2025

> Would it be okay if I open an issue to discuss adding NNX implementations to the existing nn module examples, since it could help people understand and learn the NNX module better?

Please open a new issue first to discuss what you would like to do. I would say the priority is to refresh the documentation, guides, etc.
