Skip to content

Conversation

jtgrasb
Copy link
Collaborator

@jtgrasb jtgrasb commented Aug 11, 2025

Description

Convert autograd to jax.

Wavebot tutorial is working for the first optimization.

@jtgrasb
Copy link
Collaborator Author

jtgrasb commented Aug 12, 2025

Wavebot is working now when converted to jax. Working on updating the code so that only the necessary functions use jax.numpy while the rest use numpy.

cmichelenstrofer and others added 28 commits August 20, 2025 13:00
* bug bix : DC and Nyquist frequency should not be devided by two before ifft

* Changed td_to_fd to scale single sided frequency components rather than TD signal

* minor bug fix from issue332 sandialabs#332
* added initial file changes based on sphinx_multiversion docs and WEC-Sim implementation

* removed sphinx-multiversion since it is no longer supported and made manual multiversion

* now uses absolute paths, commented out linkcheck for debugging

* fixed docstring errors in utilities module

* updating files again that somehow got reverted

* fixing path in conf.py

* don't run tutorials (will revert later)

* handle file moves correctly, fixed if statement to make other versions appear

* fixed two bugs in versions template

* reverted temp changes, changes latest to main

* switched latest to main

* main branch now in root directory of pages

* fixed URLs with change from last commit

* make other branches visible before building

* switched main branch tag for more testing

* fixed typo

* switched dev branch to an existing branch

* renamed main to latest, changed version.html file name to avoid confusion

* added prints about moving files so Sphinx output isn't misleading

* fixed typo with quotations

* changed versions.html name back because that broke things I guess

* modified contributing documentation to reflect changes

* add logic to remove duplicate 'latest' branch

* Fixed pathing when already on latest

* remove typo

* Troubleshooting complete, switching back to correct branches for deployment

* Removed extra word in docstring

* removed redundant function

* fixed pathing so returns to same file (and fixes tutorial/API docs)

* changed latest branch for demonstration

* switched back latest branch for deployment
* removed conda environment from workflows since newer capytaine/wavespectra work with Windows

* fixed unnecessary capitalization

* still create CI conda environment to fix Mac environment failures

* added conda env fully back in, push workflow deploys docs, split PR workflow

* conda environment activates again

* mambaforge instead of miniforge

* manual cache reset

* reset to older version of setup-miniconda to troubleshoot
* Try specifying subversion

* Test new cache

* revert to 3.12

* Revert comment back to normal
@jtgrasb
Copy link
Collaborator Author

jtgrasb commented Aug 21, 2025

To do:

  • Add in functions such as vmap, grad, and jit to increase code performance.
  • Resolve MacOS issues on GitHub actions.

@jtgrasb
Copy link
Collaborator Author

jtgrasb commented Sep 11, 2025

I added just-in-time compilation to the optimization (objective function, constraints, and relative gradients) using jax.jit which should speed up the code. I also had to change the call to block_diag() in the mimo_transfer_mat() function to a revised function that is now jittable. There were a couple of assert statements that I commented out for now because assertions related to the dynamic value (some were able to be left in because they were only checking static properties, which is jittable).

Here is the computation time with and without jit for the AquaHarmonics parameter sweep cell:

Implementation Time
Autograd ~ 38 s
Jax without jit ~130 s
Jax with jit ~20 s

Based on this issue thread, scipy does computation on numpy arrays which means the data type keeps getting converted back and forth between jax and numpy arrays. This is why the jax without jit has such a large computation time and the jax with jit computation time is not so much faster than autograd.

jax.vmap:

  • Would need to write functions differently to vectorize. Would be good but a lot of work - save for future work.

To do:

  • Resolve MacOS issues on GitHub actions.

@rgcoe
Copy link
Collaborator

rgcoe commented Sep 29, 2025

@cmichelenstrofer - I'm trying to debug the CI on this PR for macOS. The key part of the failing CI log seems to be the following, which is either due to some mismatch of jax and jaxlib (see jax-ml/jax#14036) and/or something about arm64 (I think this seems likely based on my work below). I'd like your help thinking about this and deciding whether our current workflow of installing partially with mamba and partially with pip really makes sense.

E   AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)
=========================== short test summary info ============================
ERROR tests/test_core.py - RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. This error is frequently encountered on macOS when running an x86 Python installation on ARM hardware. In this case, try installing an ARM build of Python. Otherwise, you may be able work around this issue by building jaxlib from source.
ERROR tests/test_integration.py - AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)
ERROR tests/test_pto.py - AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)
ERROR tests/test_utilities.py - AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)
ERROR tests/test_waves.py - AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)
!!!!!!!!!!!!!!!!!!! Interrupted: 5 errors during collection !!!!!!!!!!!!!!!!!!!!
========================= 5 errors in 99.80s (0:01:39) =========================

When I install locally, I do the following trying to replicate the install commands in pr.yml.

mamba create -n tmp_wot
mamba activate tmp_wot
mamba install python=3.12 capytaine wavespectra

This seems to work on the GitHub runner (see, e.g., https://github.com/sandialabs/WecOptTool/actions/runs/17650331026/job/50178598578), but my machine tells me there aren't capytaine and wavespectra binaries for arm64 on conda-forge (see #324 for our previous discussion on this). So I instead do the following, which works fine.

mamba create -n tmp_wot
mamba activate tmp_wot
mamba create -n tmp_wot
mamba install pip
cd wecopttool
pip install .
pip install gmsh pygmsh coveralls pytest # we could probably make a special set of optional dependencies in pyproject.toml to avoid this line
coverage run -m pytest

Based on all of this, I pushed an update to pr.yml with f1e6e51 to basically do the whole installation via pip. Note that this bypasses the installation of capytaine and wavespectra via mamba which I believe you added to allow for caching (see #35)1 It seems to fix the problem on macOS, Ubuntu still works fine, and now Windows fails due a Segmentation fault 🤪 -- maybe this is due to these warnings which we have actually been receiving for a while about cygpath?

image

Footnotes

  1. note that installation via pip on macOS and Ubuntu takes only 1-2 minutes (Windows is 8 minutes)

@jtgrasb
Copy link
Collaborator Author

jtgrasb commented Oct 3, 2025

For some reason, fully removing mamba and pip installing in editable mode using -e worked to get the tests passing on my own fork, but didn't work for this PR and the windows test is still failing.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants