Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions base_requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ flax
gcsfs
google-api-python-client
google-cloud-aiplatform
google-cloud-mldiagnostics
google-cloud-monitoring
grain[parquet]
huggingface_hub
Expand Down
168 changes: 90 additions & 78 deletions generated_requirements/cuda12-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,115 +2,121 @@
# If you need to modify dependencies, please do so in the host requirements file and run seed-env again.

absl-py>=2.3.1
aiofiles>=24.1.0
aiofiles>=25.1.0
aiohappyeyeballs>=2.6.1
aiohttp>=3.13.0
aiohttp>=3.13.1
aiosignal>=1.4.0
annotated-doc>=0.0.3
annotated-types>=0.7.0
antlr4-python3-runtime>=4.9.3
anyio>=4.11.0
aqtp>=0.9.0
array-record>=0.8.1
astroid>=3.3.11
astroid>=4.0.1
astunparse>=1.6.3
attrs>=25.3.0
auditwheel>=6.4.1
attrs>=25.4.0
auditwheel>=6.4.2
black>=24.10.0
blobfile>=3.1.0
build>=1.2.2.post1
cachetools>=6.2.0
build>=1.3.0
cachetools>=6.2.1
certifi>=2025.10.5
cffi>=2.0.0 ; platform_python_implementation == 'PyPy'
cfgv>=3.4.0
charset-normalizer>=3.4.3
chex>=0.1.90
charset-normalizer>=3.4.4
cheroot>=11.0.0
chex>=0.1.91
click>=8.3.0
cloud-accelerator-diagnostics>=0.1.1
cloud-tpu-diagnostics>=0.1.5
cloudpickle>=3.1.1
clu>=0.0.12
colorama>=0.4.6
contourpy>=1.3.2
coverage>=7.10.7
contourpy>=1.3.3
coverage>=7.11.0
cycler>=0.12.1
datasets>=4.2.0
datasets>=4.3.0
decorator>=5.2.1
dill>=0.4.0
distlib>=0.4.0
dm-tree>=0.1.9
docstring-parser>=0.17.0
editdistance>=0.8.1
einops>=0.8.1
einshape>=1.0
etils>=1.13.0
evaluate>=0.4.6
execnet>=2.1.1
fastapi>=0.118.2
filelock>=3.18.0
flatbuffers>=25.2.10
flax>=0.11.2
fonttools>=4.59.0
fastapi>=0.120.0
filelock>=3.20.0
flatbuffers>=25.9.23
flax>=0.12.0
fonttools>=4.60.1
frozenlist>=1.8.0
fsspec>=2025.7.0
fsspec>=2025.9.0
gast>=0.6.0
gcsfs>=2025.7.0
google-api-core>=2.26.0
google-api-python-client>=2.184.0
gcsfs>=2025.9.0
google-api-core>=2.27.0
google-api-python-client>=2.185.0
google-auth-httplib2>=0.2.0
google-auth-oauthlib>=1.2.2
google-auth>=2.41.1
google-cloud-aiplatform>=1.120.0
google-cloud-appengine-logging>=1.6.2
google-cloud-audit-log>=0.3.3
google-benchmark>=1.9.4
google-cloud-aiplatform>=1.122.0
google-cloud-appengine-logging>=1.7.0
google-cloud-audit-log>=0.4.0
google-cloud-bigquery>=3.38.0
google-cloud-core>=2.4.3
google-cloud-logging>=3.12.1
google-cloud-monitoring>=2.27.2
google-cloud-resource-manager>=1.14.2
google-cloud-mldiagnostics>=0.3.1
google-cloud-monitoring>=2.28.0
google-cloud-resource-manager>=1.15.0
google-cloud-storage>=2.19.0
google-crc32c>=1.7.1
google-genai>=1.42.0
google-genai>=1.46.0
google-pasta>=0.2.0
google-resumable-media>=2.7.2
googleapis-common-protos>=1.70.0
grain>=0.2.12
grpc-google-iam-v1>=0.14.2
googleapis-common-protos>=1.71.0
grain>=0.2.13
grpc-google-iam-v1>=0.14.3
grpcio-status>=1.71.2
grpcio>=1.75.1
gviz-api>=1.10.0
h11>=0.16.0
h5py>=3.14.0
h5py>=3.15.1
hf-xet>=1.1.10 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
httpcore>=1.0.9
httplib2>=0.31.0
httpx>=0.28.1
huggingface-hub>=0.35.3
humanize>=4.13.0
hypothesis>=6.136.4
huggingface-hub>=0.36.0
humanize>=4.14.0
hypothesis>=6.142.1
identify>=2.6.15
idna>=3.10
immutabledict>=4.2.1
idna>=3.11
immutabledict>=4.2.2
importlab>=0.8.1
importlib-metadata>=8.7.0
importlib-resources>=6.5.2
iniconfig>=2.1.0
isort>=6.1.0
jax-cuda12-pjrt>=0.7.0, !=0.7.1, !=0.7.2 ; sys_platform == 'linux'
jax-cuda12-plugin>=0.7.0, !=0.7.1, !=0.7.2 ; sys_platform == 'linux'
jax>=0.7.0, !=0.7.1, !=0.7.2
jaxlib>=0.7.0, !=0.7.1, !=0.7.2
isort>=7.0.0
jaraco-functools>=4.3.0
jax-cuda12-pjrt>=0.8.0 ; sys_platform == 'linux'
jax-cuda12-plugin>=0.8.0 ; sys_platform == 'linux'
jax-triton>=0.3.0
jax>=0.8.0
jaxlib>=0.8.0
jaxtyping>=0.3.3
jinja2>=3.1.6
joblib>=1.5.2
jsonlines>=4.0.0
keras>=3.11.3
kiwisolver>=1.4.8
kiwisolver>=1.4.9
libclang>=18.1.1
libcst>=1.8.5
lxml>=6.0.2
markdown-it-py>=3.0.0
markdown-it-py>=4.0.0
markdown>=3.9
markupsafe>=3.0.3
matplotlib>=3.10.3
matplotlib>=3.10.7
mccabe>=0.7.0
mdurl>=0.1.2
ml-collections>=1.1.0
Expand All @@ -129,32 +135,34 @@ networkx>=3.5
ninja>=1.13.0
nltk>=3.9.2
nodeenv>=1.9.1
numpy-typing-compat>=20250818.2.0
numpy>=2.0.2
nvidia-cublas-cu12>=12.9.1.4
nvidia-cuda-cupti-cu12>=12.9.79
nvidia-cuda-nvcc-cu12>=12.9.86
nvidia-cublas-cu12>=12.9.1.4 ; sys_platform == 'linux'
nvidia-cuda-cupti-cu12>=12.9.79 ; sys_platform == 'linux'
nvidia-cuda-nvcc-cu12>=12.9.86 ; sys_platform == 'linux'
nvidia-cuda-nvrtc-cu12>=12.9.86 ; sys_platform == 'linux'
nvidia-cuda-runtime-cu12>=12.9.79
nvidia-cudnn-cu12>=9.11.0.98
nvidia-cufft-cu12>=11.4.1.4
nvidia-cusolver-cu12>=11.7.5.82
nvidia-cusparse-cu12>=12.5.10.65
nvidia-nccl-cu12>=2.27.6
nvidia-nvjitlink-cu12>=12.9.86
nvidia-nvshmem-cu12>=3.3.9 ; sys_platform == 'linux'
nvidia-cuda-runtime-cu12>=12.9.79 ; sys_platform == 'linux'
nvidia-cudnn-cu12>=9.14.0.64 ; sys_platform == 'linux'
nvidia-cufft-cu12>=11.4.1.4 ; sys_platform == 'linux'
nvidia-cusolver-cu12>=11.7.5.82 ; sys_platform == 'linux'
nvidia-cusparse-cu12>=12.5.10.65 ; sys_platform == 'linux'
nvidia-nccl-cu12>=2.28.3 ; sys_platform == 'linux'
nvidia-nvjitlink-cu12>=12.9.86 ; sys_platform == 'linux'
nvidia-nvshmem-cu12>=3.4.5 ; sys_platform == 'linux'
oauthlib>=3.3.1
omegaconf>=2.3.0
opentelemetry-api>=1.37.0
opentelemetry-api>=1.38.0
opt-einsum>=3.4.0
optax>=0.2.6
optree>=0.17.0
orbax-checkpoint>=0.11.25
optype>=0.14.0
orbax-checkpoint>=0.11.26
packaging>=25.0
pandas>=2.3.3
parameterized>=0.9.0
pathspec>=0.12.1
pathwaysutils>=0.1.3
pillow>=11.3.0
pillow>=12.0.0
platformdirs>=4.5.0
pluggy>=1.6.0
portpicker>=1.6.0
Expand All @@ -164,37 +172,37 @@ promise>=2.3
propcache>=0.4.1
proto-plus>=1.26.1
protobuf>=5.29.5
psutil>=7.0.0
pyarrow>=21.0.0
psutil>=7.1.0
pyarrow>=22.0.0
pyasn1-modules>=0.4.2
pyasn1>=0.6.1
pycnite>=2024.7.31
pycparser>=2.23 ; implementation_name != 'PyPy' and platform_python_implementation == 'PyPy'
pycryptodomex>=3.23.0
pydantic-core>=2.41.1
pydantic>=2.12.0
pydantic-core>=2.41.4
pydantic>=2.12.3
pydot>=4.0.1
pyelftools>=0.32
pyglove>=0.4.5
pygments>=2.19.2
pyink>=24.10.1
pylint>=3.3.9
pyparsing>=3.2.3
pylint>=4.0.2
pyparsing>=3.2.5
pyproject-hooks>=1.2.0
pytest-xdist>=3.8.0
pytest>=8.4.1
pytest>=8.4.2
python-dateutil>=2.9.0.post0
pytype>=2024.10.11
pytz>=2025.2
pyyaml>=6.0.3
qwix>=0.1.1
regex>=2025.9.18
regex>=2025.10.23
requests-oauthlib>=2.0.0
requests>=2.32.5
rich>=14.1.0
rich>=14.2.0
rsa>=4.9.1
safetensors>=0.6.2
scipy>=1.16.0
scipy-stubs>=1.16.2.4
scipy>=1.16.2
sentencepiece>=0.2.1
seqio>=0.0.20
setuptools>=80.9.0
Expand All @@ -217,32 +225,36 @@ tensorflow-datasets>=4.9.9
tensorflow-metadata>=1.17.2
tensorflow-text>=2.19.0
tensorflow>=2.19.1
tensorstore>=0.1.76
tensorstore>=0.1.78
termcolor>=3.1.0
tiktoken>=0.12.0
tokamax>=0.0.4
tokenizers>=0.22.1
toml>=0.10.2
tomlkit>=0.13.3
toolz>=1.0.0
toolz>=1.1.0
tqdm>=4.67.1
transformer-engine-cu12>=2.8.0
transformer-engine-jax>=2.8.0
transformer-engine>=2.8.0
transformers>=4.57.0
transformers>=4.57.1
treescope>=0.1.10
typing-extensions>=4.14.1
triton>=3.5.0
typeguard>=2.13.3
typing-extensions>=4.15.0
typing-inspection>=0.4.2
tzdata>=2025.2
uritemplate>=4.2.0
urllib3>=2.5.0
uvicorn>=0.37.0
virtualenv>=20.34.0
uvicorn>=0.38.0
virtualenv>=20.35.3
wadler-lindig>=0.1.7
websockets>=15.0.1
werkzeug>=3.1.3
wheel>=0.45.1
wrapt>=1.17.3
wrapt>=2.0.0
xprof>=2.20.7
xxhash>=3.6.0
yarl>=1.22.0
zipp>=3.23.0
zstandard>=0.23.0
zstandard>=0.25.0
5 changes: 2 additions & 3 deletions generated_requirements/tpu-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ google-cloud-audit-log>=0.4.0
google-cloud-bigquery>=3.38.0
google-cloud-core>=2.4.3
google-cloud-logging>=3.12.1
google-cloud-mldiagnostics>=0.3.1
google-cloud-monitoring>=2.28.0
google-cloud-resource-manager>=1.15.0
google-cloud-storage>=2.19.0
google-crc32c>=1.7.1
google-genai>=1.46.0
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
google-pasta>=0.2.0
google-resumable-media>=2.7.2
google-tunix>=0.1.3
Expand Down Expand Up @@ -125,7 +125,6 @@ mdurl>=0.1.2
ml-collections>=1.1.0
ml-dtypes>=0.5.3
ml-goodput-measurement>=0.0.15
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
more-itertools>=10.8.0
mpmath>=1.3.0
msgpack>=1.1.2
Expand Down Expand Up @@ -248,4 +247,4 @@ xprof>=2.20.7
xxhash>=3.6.0
yarl>=1.22.0
zipp>=3.23.0
zstandard>=0.25.0
zstandard>=0.25.0
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ flax
gcsfs
google-api-python-client
google-cloud-aiplatform
google-cloud-mldiagnostics
google-cloud-monitoring
grain[parquet]
huggingface_hub
Expand Down
1 change: 1 addition & 0 deletions requirements_with_jax_ai_image.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
datasets @ https://github.com/huggingface/datasets/archive/6790e138c00b87a1ddc72184f89e7814cf784360.zip
flax>=0.11.0
google-api-python-client
google-cloud-mldiagnostics
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
grain[parquet]>=0.2.12
jaxtyping
Expand Down
6 changes: 5 additions & 1 deletion src/MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ colocated_python_data_input: False # experimental feature, under testing

# Training loop
steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
log_period: 100 # Flushes Tensorboard
log_period: 100 # The frequency of Tensorboard flush, gcs metrics writing, and managed profiler metrics updating.

jax_distributed_initialization_timeout: 300 # This is the default timeout in https://github.com/jax-ml/jax/blob/main/jax/_src/distributed.py
# Note there are two separate initializations - the jax coordination service (aka jax.distributed.initialize) and the backend (e.g. PjRT), the timeout above refers
Expand Down Expand Up @@ -594,6 +594,10 @@ profile_cleanly: True # If set to true, adds a block_until_ready on train state
profile_periodically_period: -1 # If set to a positive integer, profile every profile_periodically_period steps.
# This is useful to debug scenarios where performance is changing.

# Managed profiler settings, which only works when the profiler is "xplane"
managed_profiler: False # Whether to enable the managed profiler
# TODO (b/454720134): Decide if we need this configuration.
managed_profiler_run_group: "" # Used to group multiple runs. If not set, run_name will be used.

# Dump HLO options
dump_hlo: False
Expand Down
2 changes: 1 addition & 1 deletion src/MaxText/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,7 +1100,7 @@ def create_sa_config(config, query, key, attn_logits_soft_cap):
block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]),
block_q_dq=None if global_use_fused_bwd_kernel else min(global_block_q_dq, query.shape[2]),
block_kv_dq=None if global_use_fused_bwd_kernel else min(global_block_kv_dq, query.shape[2]),
use_fused_bwd_kernel=True, # tokamax only supports fused bwd kernel
use_fused_bwd_kernel=True, # tokamax only supports fused bwd kernel
q_layout=splash_attention_kernel.QKVLayout[global_q_layout],
k_layout=splash_attention_kernel.QKVLayout[global_k_layout],
v_layout=splash_attention_kernel.QKVLayout[global_v_layout],
Expand Down
Loading
Loading