-
Notifications
You must be signed in to change notification settings - Fork 25
Autograd -> Jax conversion #433
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Conversation
Wavebot is working now when converted to jax. Working on updating the code so that only the necessary functions use jax.numpy while the rest use numpy. |
* bug bix : DC and Nyquist frequency should not be devided by two before ifft * Changed td_to_fd to scale single sided frequency components rather than TD signal * minor bug fix from issue332 sandialabs#332
* added initial file changes based on sphinx_multiversion docs and WEC-Sim implementation * removed sphinx-multiversion since it is no longer supported and made manual multiversion * now uses absolute paths, commented out linkcheck for debugging * fixed docstring errors in utilities module * updating files again that somehow got reverted * fixing path in conf.py * don't run tutorials (will revert later) * handle file moves correctly, fixed if statement to make other versions appear * fixed two bugs in versions template * reverted temp changes, changes latest to main * switched latest to main * main branch now in root directory of pages * fixed URLs with change from last commit * make other branches visible before building * switched main branch tag for more testing * fixed typo * switched dev branch to an existing branch * renamed main to latest, changed version.html file name to avoid confusion * added prints about moving files so Sphinx output isn't misleading * fixed typo with quotations * changed versions.html name back because that broke things I guess * modified contributing documentation to reflect changes * add logic to remove duplicate 'latest' branch * Fixed pathing when already on latest * remove typo * Troubleshooting complete, switching back to correct branches for deployment * Removed extra word in docstring * removed redundant function * fixed pathing so returns to same file (and fixes tutorial/API docs) * changed latest branch for demonstration * switched back latest branch for deployment
* removed conda environment from workflows since newer capytaine/wavespectra work with Windows * fixed unnecessary capitalization * still create CI conda environment to fix Mac environment failures * added conda env fully back in, push workflow deploys docs, split PR workflow * conda environment activates again * mambaforge instead of miniforge * manual cache reset * reset to older version of setup-miniconda to troubleshoot
…supported versions (sandialabs#390) Co-authored-by: jtgrasb <[email protected]>
* Try specifying subversion * Test new cache * revert to 3.12 * Revert comment back to normal
0c0e42b
to
7ecc18d
Compare
To do:
|
I added just-in-time compilation to the optimization (objective function, constraints, and relative gradients) using jax.jit which should speed up the code. I also had to change the call to Here is the computation time with and without jit for the AquaHarmonics parameter sweep cell:
Based on this issue thread, scipy does computation on numpy arrays which means the data type keeps getting converted back and forth between jax and numpy arrays. This is why the jax without jit has such a large computation time and the jax with jit computation time is not so much faster than autograd. jax.vmap:
To do:
|
@cmichelenstrofer - I'm trying to debug the CI on this PR for macOS. The key part of the failing CI log seems to be the following, which is either due to some mismatch of
When I install locally, I do the following trying to replicate the install commands in pr.yml.
This seems to work on the GitHub runner (see, e.g., https://github.com/sandialabs/WecOptTool/actions/runs/17650331026/job/50178598578), but my machine tells me there aren't capytaine and wavespectra binaries for arm64 on conda-forge (see #324 for our previous discussion on this). So I instead do the following, which works fine. mamba create -n tmp_wot
mamba activate tmp_wot
mamba create -n tmp_wot
mamba install pip
cd wecopttool
pip install .
pip install gmsh pygmsh coveralls pytest # we could probably make a special set of optional dependencies in pyproject.toml to avoid this line
coverage run -m pytest Based on all of this, I pushed an update to pr.yml with f1e6e51 to basically do the whole installation via pip. Note that this bypasses the installation of capytaine and wavespectra via mamba which I believe you added to allow for caching (see #35)1 It seems to fix the problem on macOS, Ubuntu still works fine, and now Windows fails due a Segmentation fault 🤪 -- maybe this is due to these warnings which we have actually been receiving for a while about ![]() Footnotes
|
For some reason, fully removing mamba and pip installing in editable mode using |
Description
Convert autograd to jax.
Wavebot tutorial is working for the first optimization.