Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/run_tests_against_package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,5 +97,9 @@ jobs:
export MAXTEXT_ASSETS_ROOT=$(pwd)/src/MaxText/assets
export MAXTEXT_TEST_ASSETS_ROOT=$(pwd)/src/MaxText/test_assets
export MAXTEXT_PKG_DIR=$(pwd)/src/MaxText
# Set the libtpu init args only for non-GPU runs; omit them for GPU (cuda12) tests.
if [ "${{ inputs.device_type }}" != "cuda12" ]; then
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
fi
# TODO: Fix the skipped tests and remove the deselect flags
.venv/bin/python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" --durations=0 --deselect "tests/aot_hlo_identical_test.py::AotHloIdenticalTest::test_default_hlo_match" --deselect "tests/tokenizer_test.py::TokenizerTest::test_detokenize"
5 changes: 3 additions & 2 deletions base_requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ google-cloud-aiplatform
google-cloud-monitoring
grain[parquet]
huggingface_hub
jax!=0.7.1, !=0.7.2
jaxlib!=0.7.1, !=0.7.2
jax
jaxlib
jaxtyping
jsonlines
ml-collections
Expand All @@ -36,6 +36,7 @@ tensorflow-datasets
tensorflow-text
tensorflow
tiktoken
tokamax
transformers
qwix
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
Expand Down
141 changes: 78 additions & 63 deletions generated_requirements/tpu-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,36 @@
# If you need to modify dependencies, please do so in the host requirements file and run seed-env again.

absl-py>=2.3.1
aiofiles>=24.1.0
aiofiles>=25.1.0
aiohappyeyeballs>=2.6.1
aiohttp>=3.13.0
aiohttp>=3.13.1
aiosignal>=1.4.0
annotated-types>=0.7.0
antlr4-python3-runtime>=4.9.3
anyio>=4.11.0
aqtp>=0.9.0
array-record>=0.8.1
astroid>=3.3.11
astroid>=4.0.1
astunparse>=1.6.3
attrs>=25.3.0
auditwheel>=6.4.1
attrs>=25.4.0
auditwheel>=6.4.2
black>=24.10.0
blobfile>=3.1.0
build>=1.2.2.post1
cachetools>=6.2.0
build>=1.3.0
cachetools>=6.2.1
certifi>=2025.10.5
cffi>=2.0.0 ; platform_python_implementation == 'PyPy'
cfgv>=3.4.0
charset-normalizer>=3.4.3
chex>=0.1.90
charset-normalizer>=3.4.4
cheroot>=11.0.0
chex>=0.1.91
click>=8.3.0
cloud-accelerator-diagnostics>=0.1.1
cloud-tpu-diagnostics>=0.1.5
cloudpickle>=3.1.1
clu>=0.0.12
colorama>=0.4.6
contourpy>=1.3.2
coverage>=7.10.7
contourpy>=1.3.3
coverage>=7.11.0
cycler>=0.12.1
datasets>=4.2.0
decorator>=5.2.1
Expand All @@ -41,82 +41,90 @@ dm-tree>=0.1.9
docstring-parser>=0.17.0
editdistance>=0.8.1
einops>=0.8.1
einshape>=1.0
etils>=1.13.0
evaluate>=0.4.6
execnet>=2.1.1
fastapi>=0.118.2
filelock>=3.18.0
flatbuffers>=25.2.10
flax>=0.11.2
fonttools>=4.59.0
fastapi>=0.119.1
filelock>=3.20.0
flatbuffers>=25.9.23
flax>=0.12.0
fonttools>=4.60.1
frozenlist>=1.8.0
fsspec>=2025.7.0
fsspec>=2025.9.0
gast>=0.6.0
gcsfs>=2025.7.0
gcsfs>=2025.9.0
google-api-core>=2.26.0
google-api-python-client>=2.184.0
google-api-python-client>=2.185.0
google-auth-httplib2>=0.2.0
google-auth-oauthlib>=1.2.2
google-auth>=2.41.1
google-cloud-aiplatform>=1.120.0
google-cloud-appengine-logging>=1.6.2
google-cloud-audit-log>=0.3.3
google-benchmark>=1.9.4
google-cloud-aiplatform>=1.122.0
google-cloud-appengine-logging>=1.7.0
google-cloud-audit-log>=0.4.0
google-cloud-bigquery>=3.38.0
google-cloud-core>=2.4.3
google-cloud-logging>=3.12.1
google-cloud-monitoring>=2.27.2
google-cloud-resource-manager>=1.14.2
google-cloud-monitoring>=2.28.0
google-cloud-resource-manager>=1.15.0
google-cloud-storage>=2.19.0
google-crc32c>=1.7.1
google-genai>=1.42.0
google-genai>=1.46.0
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
google-pasta>=0.2.0
google-resumable-media>=2.7.2
google-tunix>=0.1.1
googleapis-common-protos>=1.70.0
grain>=0.2.12
grpc-google-iam-v1>=0.14.2
google-tunix>=0.1.3
googleapis-common-protos>=1.71.0
grain>=0.2.13
grpc-google-iam-v1>=0.14.3
grpcio-status>=1.71.2
grpcio>=1.75.1
gviz-api>=1.10.0
h11>=0.16.0
h5py>=3.14.0
h5py>=3.15.1
hf-transfer>=0.1.9
hf-xet>=1.1.10 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
httpcore>=1.0.9
httplib2>=0.31.0
httpx>=0.28.1
huggingface-hub>=0.35.3
humanize>=4.13.0
hypothesis>=6.136.4
humanize>=4.14.0
hypothesis>=6.142.1
identify>=2.6.15
idna>=3.10
immutabledict>=4.2.1
idna>=3.11
immutabledict>=4.2.2
importlab>=0.8.1
importlib-metadata>=8.7.0
importlib-resources>=6.5.2
iniconfig>=2.1.0
isort>=6.1.0
jax>=0.7.0, !=0.7.1, !=0.7.2
jaxlib>=0.7.0, !=0.7.1, !=0.7.2
isort>=7.0.0
jaraco-functools>=4.3.0
jax-triton>=0.3.0
jax>=0.8.0
jaxlib>=0.8.0
jaxtyping>=0.3.3
jinja2>=3.1.6
joblib>=1.5.2
jsonlines>=4.0.0
kagglehub>=0.3.13
keras>=3.11.3
kiwisolver>=1.4.8
kiwisolver>=1.4.9
libclang>=18.1.1
libcst>=1.8.5
libtpu>=0.0.19
libtpu>=0.0.24 ; platform_machine == 'x86_64' and sys_platform == 'linux'
llvmlite>=0.45.1
lxml>=6.0.2
markdown-it-py>=3.0.0
markdown-it-py>=4.0.0
markdown>=3.9
markupsafe>=3.0.3
matplotlib>=3.10.3
matplotlib>=3.10.7
mccabe>=0.7.0
mdurl>=0.1.2
ml-collections>=1.1.0
ml-dtypes>=0.5.3
ml-goodput-measurement>=0.0.15
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
more-itertools>=10.8.0
mpmath>=1.3.0
msgpack>=1.1.2
Expand All @@ -130,20 +138,23 @@ networkx>=3.5
ninja>=1.13.0
nltk>=3.9.2
nodeenv>=1.9.1
numba>=0.62.1
numpy-typing-compat>=20250818.2.0
numpy>=2.0.2
oauthlib>=3.3.1
omegaconf>=2.3.0
opentelemetry-api>=1.37.0
opentelemetry-api>=1.38.0
opt-einsum>=3.4.0
optax>=0.2.6
optree>=0.17.0
orbax-checkpoint>=0.11.25
optype>=0.14.0
orbax-checkpoint>=0.11.26
packaging>=25.0
pandas>=2.3.3
parameterized>=0.9.0
pathspec>=0.12.1
pathwaysutils>=0.1.3
pillow>=11.3.0
pillow>=12.0.0
platformdirs>=4.5.0
pluggy>=1.6.0
portpicker>=1.6.0
Expand All @@ -153,38 +164,38 @@ promise>=2.3
propcache>=0.4.1
proto-plus>=1.26.1
protobuf>=5.29.5
psutil>=7.0.0
psutil>=7.1.0
pyarrow>=21.0.0
pyasn1-modules>=0.4.2
pyasn1>=0.6.1
pycnite>=2024.7.31
pycparser>=2.23 ; implementation_name != 'PyPy' and platform_python_implementation == 'PyPy'
pycryptodomex>=3.23.0
pydantic-core>=2.41.1
pydantic>=2.12.0
pydantic-core>=2.41.4
pydantic>=2.12.3
pydot>=4.0.1
pyelftools>=0.32
pyglove>=0.4.5
pygments>=2.19.2
pyink>=24.10.1
pylint>=3.3.9
pyparsing>=3.2.3
pylint>=4.0.2
pyparsing>=3.2.5
pyproject-hooks>=1.2.0
pytest-xdist>=3.8.0
pytest>=8.4.1
pytest>=8.4.2
python-dateutil>=2.9.0.post0
python-dotenv>=1.1.1
pytype>=2024.10.11
pytz>=2025.2
pyyaml>=6.0.3
qwix>=0.1.1
regex>=2025.9.18
regex>=2025.10.23
requests-oauthlib>=2.0.0
requests>=2.32.5
rich>=14.1.0
rich>=14.2.0
rsa>=4.9.1
safetensors>=0.6.2
scipy>=1.16.0
scipy-stubs>=1.16.2.4
scipy>=1.16.2
sentencepiece>=0.2.1
seqio>=0.0.20
setuptools>=80.9.0
Expand All @@ -207,29 +218,33 @@ tensorflow-datasets>=4.9.9
tensorflow-metadata>=1.17.2
tensorflow-text>=2.19.0
tensorflow>=2.19.1
tensorstore>=0.1.76
tensorstore>=0.1.78
termcolor>=3.1.0
tiktoken>=0.12.0
tokamax>=0.0.3
tokenizers>=0.22.1
toml>=0.10.2
tomlkit>=0.13.3
toolz>=1.0.0
toolz>=1.1.0
tqdm>=4.67.1
transformers>=4.57.0
transformers>=4.57.1
treescope>=0.1.10
typing-extensions>=4.14.1
triton>=3.5.0
typeguard>=2.13.3
typing-extensions>=4.15.0
typing-inspection>=0.4.2
tzdata>=2025.2
uritemplate>=4.2.0
urllib3>=2.5.0
uvicorn>=0.37.0
virtualenv>=20.34.0
uvicorn>=0.38.0
virtualenv>=20.35.3
wadler-lindig>=0.1.7
websockets>=15.0.1
werkzeug>=3.1.3
wheel>=0.45.1
wrapt>=1.17.3
wrapt>=2.0.0
xprof>=2.20.7
xxhash>=3.6.0
yarl>=1.22.0
zipp>=3.23.0
zstandard>=0.23.0
zstandard>=0.25.0
2 changes: 2 additions & 0 deletions src/MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,8 @@ mu_dtype: "" # data type to store "mu" of AdamW tracking the first moment. Inher
# Setting nu_dtype is not yet supported by optax, instead nu_dtype is always inherited from weights.
# See b/399961932 for more.

adamw_fused_memory_host_offload: False # Use fused AdamW with memory host offloading.

# Stack trace parameters
collect_stack_trace: False
stack_trace_to_cloud: False # Uploads to cloud logging if True, else to the console if False.
Expand Down
Loading