
[Do not review] Upgrade Python to 3.12 and use Ubuntu 24.04 #1280

Draft: andersensam wants to merge 32 commits into apple:main from py3.12_ub24.04

Changes from all commits (32):
974b102  Minor fixes (andersensam, Jun 17, 2025)
d69dbcb  Remove tree util deprecated feature (andersensam, Jun 17, 2025)
851c619  Merge branch 'apple:main' into jax_0.6.1 (andersensam, Jun 17, 2025)
e6c17ca  Fix formatting for pre-commit (andersensam, Jun 17, 2025)
971accd  Merge branch 'apple:main' into jax_0.6.1 (andersensam, Jun 18, 2025)
2f4bb74  Update libtpu version (andersensam, Jun 18, 2025)
27c6915  Update version for libtpu (andersensam, Jun 18, 2025)
7d67405  Update pyproject.toml (andersensam, Jun 18, 2025)
a7d252b  Update pyproject.toml (andersensam, Jun 18, 2025)
c1e4dc6  Update pyproject.toml (andersensam, Jun 18, 2025)
7330900  Update pyproject.toml (andersensam, Jun 18, 2025)
3658f55  Update pyproject.toml (andersensam, Jun 19, 2025)
9eb572d  Merge branch 'apple:main' into python3.12 (andersensam, Jun 23, 2025)
06311e1  Required dep changes (andersensam, Jun 23, 2025)
c211a20  Merge branch 'apple:main' into python3.12 (andersensam, Jun 23, 2025)
2007c2c  Update package deps (andersensam, Jun 25, 2025)
7aec095  Merge branch 'apple:main' into python3.12 (andersensam, Jun 25, 2025)
992d3dd  Update requirements (andersensam, Jun 26, 2025)
087a810  Merge branch 'apple:main' into py3.12_ub24.04 (andersensam, Jun 26, 2025)
3181afe  Merge branch 'apple:main' into py3.12_ub24.04 (andersensam, Jul 1, 2025)
4ef1051  Update Dockerfile for explicit TF install (andersensam, Jul 1, 2025)
1c6c5f3  Merge branch 'apple:main' into py3.12_ub24.04 (andersensam, Jul 7, 2025)
f34b9e3  Fix tensorboard version (andersensam, Jul 8, 2025)
0f468a2  Update to account for tfio wheel (andersensam, Jul 8, 2025)
c01671e  Add modified versions for TF + TFIO and clean up Dockerfile (andersensam, Jul 9, 2025)
94201c5  Merge branch 'apple:main' into py3.12_ub24.04 (andersensam, Jul 10, 2025)
b63d84a  Merge branch 'apple:main' into py3.12_ub24.04 (andersensam, Jul 11, 2025)
173d9e8  Remove previously required config change. Update pyproject (andersensam, Jul 11, 2025)
06a74eb  Merge branch 'apple:main' into py3.12_ub24.04 (andersensam, Jul 14, 2025)
06caa05  Merge branch 'apple:main' into py3.12_ub24.04 (andersensam, Jul 16, 2025)
475e6c9  fix ts_impl references (andersensam, Jul 16, 2025)
636b131  Merge branch 'apple:main' into py3.12_ub24.04 (andersensam, Jul 17, 2025)
15 changes: 8 additions & 7 deletions Dockerfile
@@ -1,7 +1,7 @@
# syntax=docker/dockerfile:1

ARG TARGET=base
-ARG BASE_IMAGE=ubuntu:22.04
+ARG BASE_IMAGE=ubuntu:24.04

FROM ${BASE_IMAGE} AS base

@@ -15,7 +15,8 @@ RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && \
apt-get update -y && \
apt-get install -y apt-transport-https ca-certificates gcc g++ \
-git screen ca-certificates google-perftools google-cloud-cli python3.10-venv && apt clean -y
+git screen google-perftools google-cloud-cli python3.12-venv && \
+apt clean -y

# Setup.
RUN mkdir -p /root
@@ -26,7 +27,7 @@ COPY pyproject.toml pyproject.toml
RUN mkdir axlearn && touch axlearn/__init__.py
# Setup venv to suppress pip warnings.
ENV VIRTUAL_ENV=/opt/venv
-RUN python3 -m venv $VIRTUAL_ENV
+RUN python3.12 -m venv $VIRTUAL_ENV
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# Install dependencies.
RUN pip install --upgrade pip && pip install uv flit && pip cache purge
@@ -74,7 +75,7 @@ COPY . .

# Dataflow workers can't start properly if the entrypoint is not set
# See: https://cloud.google.com/dataflow/docs/guides/build-container-image#use_a_custom_base_image
-COPY --from=apache/beam_python3.10_sdk:2.52.0 /opt/apache/beam /opt/apache/beam
+COPY --from=apache/beam_python3.12_sdk:2.52.0 /opt/apache/beam /opt/apache/beam
ENTRYPOINT ["/opt/apache/beam/boot"]

################################################################################
@@ -85,7 +86,7 @@ FROM base AS tpu

ARG EXTRAS=

-ENV UV_FIND_LINKS=https://storage.googleapis.com/jax-releases/libtpu_releases.html
+ENV UV_FIND_LINKS="https://storage.googleapis.com/jax-releases/libtpu_releases.html,https://storage.googleapis.com/axlearn-wheels/wheels.html"
# Ensure we install the TPU version, even if building locally.
# Jax will fallback to CPU when run on a machine without TPU.
RUN uv pip install --prerelease=allow .[core,tpu] && uv cache clean
@@ -99,9 +100,9 @@ COPY . .
FROM base AS gpu

# TODO(markblee): Support extras.
-ENV UV_FIND_LINKS=https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+ENV UV_FIND_LINKS="https://storage.googleapis.com/jax-releases/jax_cuda_releases.html,https://storage.googleapis.com/axlearn-wheels/wheels.html"
# Enable the CUDA repository and install the required libraries (libnvrtc.so)
-RUN curl -o cuda-keyring_1.1-1_all.deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb && \
+RUN curl -o cuda-keyring_1.1-1_all.deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb && \
dpkg -i cuda-keyring_1.1-1_all.deb && \
apt-get update && apt-get install -y cuda-libraries-dev-12-8 ibverbs-utils && \
apt clean -y
17 changes: 9 additions & 8 deletions axlearn/common/array_serialization.py
@@ -252,7 +252,7 @@ async def _async_serialize(
and arr_inp.is_fully_addressable
)
# pylint: disable-next=protected-access
-if not serialization._spec_has_metadata(tensorstore_spec):
+if not serialization.ts_impl._spec_has_metadata(tensorstore_spec):
# pylint: disable-next=protected-access
tensorstore_spec["metadata"] = serialization._get_metadata(arr_inp)
if "dtype" not in tensorstore_spec:
@@ -274,14 +274,14 @@ async def _async_serialize(
# does no I/O operation and returns the tensorstore object. For every process other than `0`,
# we open with `assume_metadata=True`.
if jax.process_index() == 0:
-await serialization.ts.open(
-serialization.ts.Spec(tensorstore_spec),
+await serialization.ts_impl.ts.open(
+serialization.ts_impl.ts.Spec(tensorstore_spec),
create=True,
open=True,
context=serialization.TS_CONTEXT,
)
-t = await serialization.ts.open(
-serialization.ts.Spec(tensorstore_spec),
+t = await serialization.ts_impl.ts.open(
+serialization.ts_impl.ts.Spec(tensorstore_spec),
open=True,
assume_metadata=True,
context=serialization.TS_CONTEXT,
@@ -417,7 +417,7 @@ async def _async_deserialize(
async def cb(index: array.Index, device: jax.Device):
requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
restricted_domain = t.domain.intersect(requested_domain)
-requested_bytes = serialization.estimate_read_memory_footprint(t, restricted_domain)
+requested_bytes = serialization.ts_impl.estimate_read_memory_footprint(t, restricted_domain)
# Limit the bytes read for every shard.
await byte_limiter.wait_for_bytes(requested_bytes)
read_ts = t[restricted_domain]
@@ -478,7 +478,8 @@ async def cb(index: array.Index, device: jax.Device):
await byte_limiter.release_bytes(requested_bytes)
return result

-return await serialization.create_async_array_from_callback(shape, in_sharding, cb)
+# pylint: disable-next=protected-access
+return await serialization.ts_impl._create_async_array_from_callback(shape, in_sharding, cb)


# Reference:
@@ -560,7 +561,7 @@ def serialize(
# pylint: disable-next=redefined-outer-name
async def _run_serializer():
future_writer = jax.tree.map(
-serialization.async_serialize, arrays, tensorstore_specs, commit_futures
+serialization.ts_impl.async_serialize, arrays, tensorstore_specs, commit_futures
)
return await asyncio.gather(*future_writer)

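
All of the edits in this file track jax 0.6's relocation of the tensorstore serialization helpers: names previously reached as `serialization.<name>` now live one level deeper under `serialization.ts_impl`. A minimal sketch of the new access path, assuming jax==0.6.2 as pinned in pyproject.toml below (attribute paths are the ones used in this diff):

```python
# Sketch only: where the helpers used by array_serialization.py now live,
# assuming jax==0.6.2.
from jax.experimental.array_serialization import serialization

ts = serialization.ts_impl.ts  # the tensorstore module itself
async_serialize = serialization.ts_impl.async_serialize
estimate_read = serialization.ts_impl.estimate_read_memory_footprint

# Private helpers keep their underscore names under the new module.
# pylint: disable-next=protected-access
spec_has_metadata = serialization.ts_impl._spec_has_metadata
```

Keeping the `serialization.ts_impl.` prefix at every call site (rather than aliasing once at import time) makes the test file's `mock.patch` targets line up exactly with the paths the code resolves at call time.
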
14 changes: 8 additions & 6 deletions axlearn/common/array_serialization_test.py
@@ -96,7 +96,7 @@ def test_async_serialize_d2h_sync(self, sharded):
arr = self._create_partially_replicated_array(sharded)

ts_open_handle: Any = None
-old_open = array_serialization.serialization.ts.open
+old_open = array_serialization.serialization.ts_impl.ts.open

async def ts_open_patch(*args, **kwargs):
nonlocal ts_open_handle
@@ -118,7 +118,7 @@ def transfer_to_host_patch(*args, **kwargs):

d2h_future = array_serialization.futures.Future()
with mock.patch(
f"{array_serialization.__name__}.serialization.ts.open",
f"{array_serialization.__name__}.serialization.ts_impl.ts.open",
ts_open_patch,
), get_tensorstore_spec(arr) as spec, mock.patch(
f"{array_serialization.__name__}._transfer_to_host", transfer_to_host_patch
@@ -144,7 +144,7 @@ def transfer_to_host_patch(*args, **kwargs):
arr_host = jax.device_get(arr)
d2h_future = array_serialization.futures.Future()
with mock.patch(
f"{array_serialization.__name__}.serialization.ts.open",
f"{array_serialization.__name__}.serialization.ts_impl.ts.open",
ts_open_patch,
), get_tensorstore_spec(arr) as spec, mock.patch(
f"{array_serialization.__name__}._transfer_to_host", transfer_to_host_patch
@@ -178,7 +178,7 @@ async def ts_open_patch(*_, **__):

d2h_future = array_serialization.futures.Future()
with mock.patch(
f"{array_serialization.__name__}.serialization.ts.open",
f"{array_serialization.__name__}.serialization.ts_impl.ts.open",
ts_open_patch,
), get_tensorstore_spec(arr) as spec:
f = _CommitFuture(
@@ -276,8 +276,10 @@ async def _copy_to_host_patch(shard_infos: list[_ShardInfo]):
mock.patch(
f"{array_serialization.__name__}.serialization._get_metadata", lambda *_: {}
),
mock.patch(f"{array_serialization.__name__}.serialization.ts.open", open_patch),
mock.patch(f"{array_serialization.__name__}.serialization.ts.Spec", mock.MagicMock()),
mock.patch(f"{array_serialization.__name__}.serialization.ts_impl.ts.open", open_patch),
mock.patch(
f"{array_serialization.__name__}.serialization.ts_impl.ts.Spec", mock.MagicMock()
),
):
manager.serialize(arrays, tensorstore_specs, on_commit_callback=lambda: None)
manager.wait_until_finished()
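
The test updates are mechanical but worth spelling out: `mock.patch` resolves its target string at patch time, so every test double must name the exact attribute path the code under test walks, which now includes the `ts_impl` hop. A condensed sketch of the pattern used above (the patch body is hypothetical, for illustration only):

```python
from unittest import mock

from axlearn.common import array_serialization


async def ts_open_patch(*args, **kwargs):
    """Hypothetical stand-in for tensorstore's open, for illustration only."""
    return mock.MagicMock()


# The target string must mirror how array_serialization.py resolves
# serialization.ts_impl.ts.open at call time; patching the old
# serialization.ts.open path would no longer intercept anything.
with mock.patch(
    f"{array_serialization.__name__}.serialization.ts_impl.ts.open", ts_open_patch
):
    ...  # exercise the serialize/deserialize paths here
```
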
2 changes: 1 addition & 1 deletion axlearn/common/causal_lm.py
@@ -202,7 +202,7 @@ def forward(
live_targets = _infer_live_targets(input_batch)
num_targets = live_targets.sum()

logging.info("Module outputs: %s", jax.tree_structure(module_outputs))
logging.info("Module outputs: %s", jax.tree_util.tree_structure(module_outputs))
accumulation = []
for k, v in flatten_items(module_outputs):
if re.fullmatch(regex, k):
2 changes: 1 addition & 1 deletion axlearn/common/gradient_accumulation.py
@@ -33,7 +33,7 @@ def _compute_minibatch_size(input_batch: Nested[Tensor], *, steps: int) -> int:
if steps <= 0:
raise ValueError("Accumulation steps need to be a positive integer.")

-input_batch_sizes = jax.tree_leaves(jax.tree.map(lambda x: x.shape[0], input_batch))
+input_batch_sizes = jax.tree_util.tree_leaves(jax.tree.map(lambda x: x.shape[0], input_batch))

if len(input_batch_sizes) == 0:
raise ValueError("Input batch is empty.")
2 changes: 1 addition & 1 deletion axlearn/common/update_transformation.py
@@ -186,7 +186,7 @@ def real_transform(_):
return new_updates.delta_updates, new_state

def stop_transform(_):
-return jax.tree_map(jnp.zeros_like, updates.delta_updates), prev_state
+return jax.tree_util.tree_map(jnp.zeros_like, updates.delta_updates), prev_state

# We do the computation regardless of the should_update value, so we could have
# equally used jnp.where() here instead.
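
The one-line changes in causal_lm.py, gradient_accumulation.py, and update_transformation.py above are all the same migration: jax 0.6 drops the long-deprecated top-level aliases (`jax.tree_map`, `jax.tree_leaves`, `jax.tree_structure`), so call sites move to their `jax.tree_util` equivalents. A minimal before/after sketch:

```python
import jax
import jax.numpy as jnp

tree = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

# Removed top-level aliases (fail on jax 0.6):
#   jax.tree_map(jnp.zeros_like, tree)
#   jax.tree_leaves(tree)
#   jax.tree_structure(tree)

# Replacements used in this PR:
zeros = jax.tree_util.tree_map(jnp.zeros_like, tree)
leaves = jax.tree_util.tree_leaves(tree)
treedef = jax.tree_util.tree_structure(tree)

# The newer jax.tree namespace is equivalent; the codebase already uses
# jax.tree.map, e.g. in gradient_accumulation.py above.
zeros_too = jax.tree.map(jnp.zeros_like, tree)
```

The same substitution recurs in utils.py and common_test.py below.
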
12 changes: 6 additions & 6 deletions axlearn/common/utils.py
@@ -1907,12 +1907,12 @@ def pytree_children(node: Any) -> Sequence[tuple[KeyEntry, Any]]:
```
"""
# pylint: disable-next=protected-access
-registry_with_keypaths = jax._src.tree_util._registry_with_keypaths
+# registry_with_keypaths = jax._src.tree_util._registry_with_keypaths

-key_handler = registry_with_keypaths.get(type(node))
-if key_handler:
-key_children, _ = key_handler.flatten_with_keys(node)
-return key_children
+# key_handler = registry_with_keypaths.get(type(node))
+# if key_handler:
+# key_children, _ = key_handler.flatten_with_keys(node)
+# return key_children

flat = jax.tree_util.default_registry.flatten_one_level(node)
if flat is None:
@@ -2028,7 +2028,7 @@ def validate_contains_paths(x: Nested[Tensor], paths: Sequence[str]):
except KeyError as e:
raise ValueError(
f"Input is expected to contain '{path}'; "
f"instead, it contains: '{jax.tree_structure(x)}'."
f"instead, it contains: '{jax.tree_util.tree_structure(x)}'."
) from e


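
With the keypath-registry fast path commented out, `pytree_children` now relies solely on the default registry's one-level flatten (the unchanged fallback in the diff above). A small sketch of that contract, grounded in the call shown there; the None-for-leaves behavior is exactly what the `if flat is None` check relies on:

```python
import jax

# Non-leaf node: returns one level of children plus auxiliary metadata
# (here the dict's values in key order; nested containers stay nested).
flat = jax.tree_util.default_registry.flatten_one_level({"a": 1, "b": [2, 3]})
children, _aux = flat

# Leaf node: returns None, which pytree_children treats as "no children".
assert jax.tree_util.default_registry.flatten_one_level(3) is None
```
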
4 changes: 3 additions & 1 deletion axlearn/experiments/text/gpt/common_test.py
@@ -52,7 +52,9 @@ def test_mesh_axes(self):
# axis for multiple dims.
for v in cfg.input.input_partitioner.path_rank_to_partition.values():
visited = set()
-for axis in jax.tree_leaves(tuple(v)): # Cast to tuple since PartitionSpec is a leaf.
+for axis in jax.tree_util.tree_leaves(
+tuple(v)
+): # Cast to tuple since PartitionSpec is a leaf.
self.assertNotIn(axis, visited)
visited.add(axis)
self.assertGreater(len(visited), 0)
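
The reformatted loop keeps the important subtlety in its trailing comment: `PartitionSpec` is registered as a pytree leaf, so it must be cast to a tuple before `tree_leaves` will yield the individual mesh axis names. A quick sketch (not from the PR):

```python
import jax
from jax.sharding import PartitionSpec

spec = PartitionSpec("data", "model")

# As a pytree leaf, the spec comes back whole from tree_leaves...
assert jax.tree_util.tree_leaves(spec) == [spec]

# ...but casting to tuple exposes the axis names, which is what the
# duplicate-axis check in test_mesh_axes iterates over.
assert jax.tree_util.tree_leaves(tuple(spec)) == ["data", "model"]
```
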
45 changes: 23 additions & 22 deletions pyproject.toml
@@ -7,13 +7,13 @@ name = "axlearn"
version = "0.1.7"
description = "AXLearn"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.12"

# Production dependencies.
# Minimal requirments for axlearn/common/config.py.
dependencies = [
"attrs>=23.1.0", # We use `type` in `attrs.field`
"numpy==1.26.4", # verified with tensorflow 2.14 RaggedTensor
"numpy>=1.26.4", # verified with tensorflow 2.14 RaggedTensor
]

[project.optional-dependencies]
@@ -23,20 +23,20 @@ core = [
"absl-py==2.1.0",
"chex==0.1.88",
"importlab==0.8.1", # breaks pytype on 0.8
"jax==0.5.3",
"jaxlib==0.5.3",
"ml-dtypes==0.4.1",
"jax==0.6.2",
"jaxlib==0.6.2",
"ml-dtypes==0.5.1",
"msgpack==1.1.0", # for checkpointing.
"nltk==3.7", # for text preprocessing
"optax==0.1.7", # optimizers (0.1.0 has known bugs).
"portpicker",
"protobuf>=3.20.3",
"tensorboard-plugin-profile==2.15.1",
"tensorboard-plugin-profile==2.19.0",
# This has both x86 and arm64 wheels. Underneath the hood it uses tensorflow-macos since 2.13.
"tensorflow==2.17.1",
"tensorflow==2.19.1",
"tensorflow-datasets>=4.9.2",
"tensorflow-io>=0.37.1", # for tensorflow-2.16. Note that 0.37.0 results in "pure virtual method called".
"tensorflow_text==2.17.0; platform_machine == 'x86_64'", # implied by seqio, but also used directly for text processing
"tensorflow-io==0.37.2", # for tensorflow-2.16. Note that 0.37.0 results in "pure virtual method called".
"tensorflow_text==2.19.0; platform_machine == 'x86_64'", # implied by seqio, but also used directly for text processing
"tensorstore>=0.1.63", # used for supporting GDA checkpoints
"toml", # for config management
"typing-extensions==4.12.2",
@@ -69,15 +69,15 @@ dev = [
"pylint==2.17.7",
"pytest", # test runner
"pytest-xdist", # pytest plugin for test parallelism
"pytype==2022.4.22", # type checking
"pytype==2024.10.11", # type checking
"scikit-learn==1.5.2", # test-only
# Fix AttributeError: module 'scipy.linalg' has no attribute 'tril' and related scipy import errors.
"scipy==1.12.0",
"sentencepiece != 0.1.92",
"tqdm", # test-only
"timm==0.6.12", # DiT Dependency test-only
"torch>=1.12.1", # test-only
"torchvision==0.16.1", # test-only
"timm==1.0.17", # DiT Dependency test-only
"torch==2.7.1", # test-only
"torchvision==0.22.1", # test-only
"transformers==4.51.3", # test-only
"wandb", # test-only
"wrapt", # implied by tensorflow-datasets, but also used in config tests.
@@ -108,8 +108,9 @@ gcp = [
# Note: Specify -f https://storage.googleapis.com/jax-releases/libtpu_releases.html during install.
tpu = [
"axlearn[gcp]",
"jax[tpu]==0.5.3", # must be >=0.4.19 for compat with v5p.
"pathwaysutils==0.1.1", # For JAX+Pathways single-controller accelerator coordinator.
"jax[tpu]==0.6.2", # must be >=0.4.19 for compat with v5p.
"pathwaysutils==0.1.1", # For JAX+Pathways single-controller accelerator coordinator.,
"libtpu==0.0.17",
]
# Vertex AI tensorboard. TODO(markblee): Merge with `gcp`.
vertexai_tensorboard = [
@@ -120,8 +121,8 @@ vertexai_tensorboard = [
"setuptools==65.7.0",
# Pin version to fix Tensorboard uploader TypeError: can only concatenate str (not "NoneType") to str
# https://github.com/googleapis/python-aiplatform/commit/4f982ab254b05fe44a9d2ed959fca2793961b56c
"google-cloud-aiplatform[tensorboard]==1.61.0",
"tensorboard",
"google-cloud-aiplatform[tensorboard]",
"tensorboard==2.19.0",
]
# Dataflow dependencies.
dataflow = [
@@ -132,9 +133,9 @@ dataflow = [
]
# GPU custom kernel dependency.
gpu = [
"triton==2.1.0",
"jax[cuda12]==0.5.3",
"nvidia-ml-py==12.560.30",
"triton>=2.1.0",
"jax[cuda12]==0.6.2",
"nvidia-ml-py>=12.560.30",
]
# Open API inference.
open_api = [
@@ -171,11 +172,11 @@ axlearn = "axlearn.cli:main"

[tool.black]
line-length = 100
-target-version = ['py38', 'py39']
+target-version = ['py310', 'py312']

[tool.ruff]
line-length = 100
-target-version = 'py39'
+target-version = ['py310', 'py312']

[tool.pytest.ini_options]
addopts = "-rs -s -p no:warnings --junitxml=test-results/testing.xml"
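
Since the jax, jaxlib, libtpu, TensorFlow, and tooling pins all move together in this PR, a quick environment check can catch a partial upgrade. A minimal sketch, not part of the PR; the expected values mirror the pins in pyproject.toml above:

```python
# Sanity-check sketch for the upgraded stack: Python >= 3.12, jax 0.6.2,
# tensorflow 2.19.1.
import sys

import jax
import tensorflow as tf

assert sys.version_info >= (3, 12), "this PR raises requires-python to >=3.12"
print("jax:", jax.__version__)          # expected 0.6.2
print("tensorflow:", tf.__version__)    # expected 2.19.1
print("devices:", jax.devices())        # jax falls back to CPU without TPU/GPU
```
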