Description
Describe the bug
When I try to import the PGD attack from the JAX module, I get the following error:
ModuleNotFoundError: No module named 'jax.experimental.stax'
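The cause is that the PGD attack imports the FGSM implementation, and the FGSM module (jax/attacks/fast_gradient_method.py) imports the logsoftmax function from jax.experimental.stax. That module has since been moved to jax.example_libraries.stax and is no longer available under jax.experimental in recent JAX releases.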
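To confirm which import location a given environment provides before patching, a check along these lines could be run in a Colab cell. This is a minimal sketch: module_available and report are hypothetical helper names, not part of cleverhans or JAX. It relies on importlib.util.find_spec, which locates a module without executing it (although it does import the parent packages of a dotted name).

```python
import importlib.util
from typing import Dict, Sequence


def module_available(name: str) -> bool:
    """Return True if `name` can be located in this environment.

    Hypothetical helper. For dotted names, find_spec imports the parent
    packages first; if a parent is missing (e.g. JAX not installed at all),
    an ImportError is raised, which is treated as "not available".
    """
    try:
        return importlib.util.find_spec(name) is not None
    except ImportError:
        return False


def report(names: Sequence[str] = ("jax.experimental.stax", "jax.nn")) -> Dict[str, bool]:
    """Map each module name to whether it can be located.

    The defaults are the module the FGSM file imports from and the
    proposed replacement.
    """
    return {name: module_available(name) for name in names}
```

In an environment affected by this bug, report() would be expected to show jax.experimental.stax as unavailable and jax.nn as available.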
To Reproduce
Steps to reproduce the behavior:
- Open a Google Colab notebook.
- Install cleverhans by running:
    !pip install git+https://github.com/cleverhans-lab/cleverhans.git#egg=cleverhans
- Then try to import the PGD attack:
    from cleverhans.jax.attacks.projected_gradient_descent import projected_gradient_descent
- Execute the cell and see the error.
Expected behavior
Importing the PGD attack should succeed. The logsoftmax function should be imported from the jax.nn package instead. Changing the import to

    from jax.nn import log_softmax as logsoftmax

should make the error go away.
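If the fix needs to tolerate more than one JAX layout, the import could be resolved at runtime instead of being hard-coded. The sketch below is an assumption, not the cleverhans code: resolve_attr and DEFAULT_CANDIDATES are hypothetical names. It tries jax.nn.log_softmax first and only falls back to the old jax.experimental.stax.logsoftmax location if that fails.

```python
import importlib
from typing import Callable, Sequence, Tuple

# Hypothetical candidate list, in order of preference: the public jax.nn
# function first, then the old stax location the FGSM module used.
DEFAULT_CANDIDATES: Tuple[Tuple[str, str], ...] = (
    ("jax.nn", "log_softmax"),
    ("jax.experimental.stax", "logsoftmax"),
)


def resolve_attr(candidates: Sequence[Tuple[str, str]] = DEFAULT_CANDIDATES) -> Callable:
    """Return the first attribute that can be imported from `candidates`.

    Hypothetical helper. Each candidate is a (module name, attribute name)
    pair. Modules that fail to import, or lack the attribute, are skipped;
    if nothing resolves, an ImportError listing every failure is raised.
    """
    errors = []
    for module_name, attr_name in candidates:
        try:
            module = importlib.import_module(module_name)
        except ImportError as exc:
            errors.append(f"{module_name}: {exc}")
            continue
        if hasattr(module, attr_name):
            return getattr(module, attr_name)
        errors.append(f"{module_name}: no attribute {attr_name!r}")
    raise ImportError("could not resolve any candidate: " + "; ".join(errors))
```

In fast_gradient_method.py this would replace the failing import with logsoftmax = resolve_attr(). The plain one-line import from jax.nn is simpler and is enough for any JAX release that ships jax.nn.log_softmax; the fallback only adds value if older layouts must keep working.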
System configuration
- Google Colab's default environment