Fix regression with metrics passed to compile. (#22663)
#2
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
name: Keras GPU Tests

# Triggers: pushes to master, newly created releases, and PRs that have a label
# removed. The job-level `if` further narrows PR runs to removal of the
# "kokoro:force-run" label.
on:
  push:
    branches: [master]
  pull_request:
    types: [unlabeled]
  release:
    types: [created]

permissions:
  contents: read

# Every `unlabeled` event starts a run of this workflow, even when the job's
# `if` then skips it. The removed label's name is part of the group key, so
# removing an unrelated label does not cancel an in-progress run started by
# removing "kokoro:force-run". The name is empty for push/release events, so
# those keep their previous grouping.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}-${{ github.event.label.name }}
  cancel-in-progress: true
jobs:
  # Runs the Keras test suite once per backend (jax, tensorflow, torch) inside a
  # python:3.11-slim container on the linux-x86-g2-16-l4-1gpu runner.
  test-in-container:
    name: Run tests on GPU
    runs-on: linux-x86-g2-16-l4-1gpu
    # Only run on pushes to master, on created releases, or when the
    # "kokoro:force-run" label is removed from a pull request.
    if: |
      github.event_name == 'push' ||
      github.event_name == 'release' ||
      (github.event.action == 'unlabeled' && github.event.label.name == 'kokoro:force-run')
    strategy:
      # Let every backend finish even if another backend's leg fails.
      fail-fast: false
      matrix:
        backend: [jax, tensorflow, torch]
    container:
      image: python:3.11-slim
      # NOTE(review): presumably required for the container to reach the
      # runner's GPU / network setup — confirm before changing.
      options: --privileged --network host
    steps:
      - name: Checkout ${{ github.ref }}
        uses: actions/checkout@v6
      - name: Check CUDA Version
        run: nvidia-smi
      - name: Install Torch Prerequisites
        # Torch Dynamo / triton requires the C++ compiler to be installed, which is part of `build-essential`.
        if: ${{ matrix.backend == 'torch' }}
        run: |
          apt-get update
          apt-get -y install build-essential
      - name: Install Dependencies
        run: pip install --no-cache-dir -r requirements-${{ matrix.backend }}-cuda.txt
      - name: Set Keras Backend
        # Makes KERAS_BACKEND visible to all later steps. $GITHUB_ENV is quoted
        # (shellcheck SC2086) so the redirect target is never word-split or globbed.
        run: echo "KERAS_BACKEND=${{ matrix.backend }}" >> "$GITHUB_ENV"
      # The three "Verify" steps fail the job early if the selected backend
      # cannot see a GPU, instead of silently running the suite on CPU.
      - name: Verify TF Installation
        if: ${{ matrix.backend == 'tensorflow' }}
        run: python3 -c "import tensorflow as tf; print('Tensorflow devices:', tf.config.list_logical_devices()); assert len(tf.config.list_physical_devices('GPU')) > 0"
      - name: Verify JAX Installation
        if: ${{ matrix.backend == 'jax' }}
        run: python3 -c "import jax; print('JAX devices:', jax.devices()); assert jax.default_backend() == 'gpu'"
      - name: Verify Torch Installation
        if: ${{ matrix.backend == 'torch' }}
        run: python3 -c "import torch; print('Torch devices:', [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]); assert torch.cuda.device_count() > 0"
      - name: Run Tests
        run: pytest keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml
      # Distribution tests run only for the JAX backend.
      - name: Run Distribution Tests
        if: ${{ matrix.backend == 'jax' }}
        run: pytest keras/src/distribution/distribution_lib_test.py --cov=keras --cov-config=pyproject.toml