diff --git a/Dockerfile b/Dockerfile
index 29db664d3..3a0529d66 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,7 +1,7 @@
 # syntax=docker/dockerfile:1
 
 ARG TARGET=base
-ARG BASE_IMAGE=ubuntu:22.04
+ARG BASE_IMAGE=ubuntu:24.04
 
 FROM ${BASE_IMAGE} AS base
 
@@ -15,7 +15,8 @@ RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.
     curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && \
     apt-get update -y && \
     apt-get install -y apt-transport-https ca-certificates gcc g++ \
-    git screen ca-certificates google-perftools google-cloud-cli python3.10-venv && apt clean -y
+    git screen google-perftools google-cloud-cli python3.12-venv && \
+    apt clean -y
 
 # Setup.
 RUN mkdir -p /root
@@ -26,7 +27,7 @@ COPY pyproject.toml pyproject.toml
 RUN mkdir axlearn && touch axlearn/__init__.py
 # Setup venv to suppress pip warnings.
 ENV VIRTUAL_ENV=/opt/venv
-RUN python3 -m venv $VIRTUAL_ENV
+RUN python3.12 -m venv $VIRTUAL_ENV
 ENV PATH="$VIRTUAL_ENV/bin:$PATH"
 # Install dependencies.
 RUN pip install --upgrade pip && pip install uv flit && pip cache purge
@@ -74,7 +75,7 @@ COPY . .
 
 # Dataflow workers can't start properly if the entrypoint is not set
 # See: https://cloud.google.com/dataflow/docs/guides/build-container-image#use_a_custom_base_image
-COPY --from=apache/beam_python3.10_sdk:2.52.0 /opt/apache/beam /opt/apache/beam
+COPY --from=apache/beam_python3.12_sdk:2.52.0 /opt/apache/beam /opt/apache/beam
 ENTRYPOINT ["/opt/apache/beam/boot"]
 
 ################################################################################
@@ -85,7 +86,7 @@ FROM base AS tpu
 
 ARG EXTRAS=
 
-ENV UV_FIND_LINKS=https://storage.googleapis.com/jax-releases/libtpu_releases.html
+ENV UV_FIND_LINKS="https://storage.googleapis.com/jax-releases/libtpu_releases.html,https://storage.googleapis.com/axlearn-wheels/wheels.html"
 # Ensure we install the TPU version, even if building locally.
 # Jax will fallback to CPU when run on a machine without TPU.
 RUN uv pip install --prerelease=allow .[core,tpu] && uv cache clean
@@ -99,9 +100,9 @@ COPY . .
 FROM base AS gpu
 
 # TODO(markblee): Support extras.
-ENV UV_FIND_LINKS=https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+ENV UV_FIND_LINKS="https://storage.googleapis.com/jax-releases/jax_cuda_releases.html,https://storage.googleapis.com/axlearn-wheels/wheels.html"
 # Enable the CUDA repository and install the required libraries (libnvrtc.so)
-RUN curl -o cuda-keyring_1.1-1_all.deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb && \
+RUN curl -o cuda-keyring_1.1-1_all.deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb && \
     dpkg -i cuda-keyring_1.1-1_all.deb && \
     apt-get update && apt-get install -y cuda-libraries-dev-12-8 ibverbs-utils && \
     apt clean -y
diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py
index 9ba0bbf81..4d99c61cc 100644
--- a/axlearn/common/array_serialization.py
+++ b/axlearn/common/array_serialization.py
@@ -252,7 +252,7 @@ async def _async_serialize(
         and arr_inp.is_fully_addressable
     )
     # pylint: disable-next=protected-access
-    if not serialization._spec_has_metadata(tensorstore_spec):
+    if not serialization.ts_impl._spec_has_metadata(tensorstore_spec):
         # pylint: disable-next=protected-access
         tensorstore_spec["metadata"] = serialization._get_metadata(arr_inp)
     if "dtype" not in tensorstore_spec:
@@ -274,14 +274,14 @@ async def _async_serialize(
     # does no I/O operation and returns the tensorstore object. For every process other than `0`,
     # we open with `assume_metadata=True`.
     if jax.process_index() == 0:
-        await serialization.ts.open(
-            serialization.ts.Spec(tensorstore_spec),
+        await serialization.ts_impl.ts.open(
+            serialization.ts_impl.ts.Spec(tensorstore_spec),
             create=True,
             open=True,
             context=serialization.TS_CONTEXT,
         )
-    t = await serialization.ts.open(
-        serialization.ts.Spec(tensorstore_spec),
+    t = await serialization.ts_impl.ts.open(
+        serialization.ts_impl.ts.Spec(tensorstore_spec),
         open=True,
         assume_metadata=True,
         context=serialization.TS_CONTEXT,
@@ -417,7 +417,7 @@ async def _async_deserialize(
     async def cb(index: array.Index, device: jax.Device):
         requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
         restricted_domain = t.domain.intersect(requested_domain)
-        requested_bytes = serialization.estimate_read_memory_footprint(t, restricted_domain)
+        requested_bytes = serialization.ts_impl.estimate_read_memory_footprint(t, restricted_domain)
         # Limit the bytes read for every shard.
         await byte_limiter.wait_for_bytes(requested_bytes)
         read_ts = t[restricted_domain]
@@ -478,7 +478,8 @@ async def cb(index: array.Index, device: jax.Device):
             await byte_limiter.release_bytes(requested_bytes)
         return result
 
-    return await serialization.create_async_array_from_callback(shape, in_sharding, cb)
+    # pylint: disable-next=protected-access
+    return await serialization.ts_impl._create_async_array_from_callback(shape, in_sharding, cb)
 
 
 # Reference:
@@ -560,7 +561,7 @@ def serialize(
         # pylint: disable-next=redefined-outer-name
         async def _run_serializer():
             future_writer = jax.tree.map(
-                serialization.async_serialize, arrays, tensorstore_specs, commit_futures
+                serialization.ts_impl.async_serialize, arrays, tensorstore_specs, commit_futures
             )
             return await asyncio.gather(*future_writer)
 
diff --git a/axlearn/common/array_serialization_test.py b/axlearn/common/array_serialization_test.py
index 5ba234230..29f95649e 100644
--- a/axlearn/common/array_serialization_test.py
+++ b/axlearn/common/array_serialization_test.py
@@ -96,7 +96,7 @@ def test_async_serialize_d2h_sync(self, sharded):
         arr = self._create_partially_replicated_array(sharded)
         ts_open_handle: Any = None
 
-        old_open = array_serialization.serialization.ts.open
+        old_open = array_serialization.serialization.ts_impl.ts.open
 
         async def ts_open_patch(*args, **kwargs):
             nonlocal ts_open_handle
@@ -118,7 +118,7 @@ def transfer_to_host_patch(*args, **kwargs):
 
         d2h_future = array_serialization.futures.Future()
         with mock.patch(
-            f"{array_serialization.__name__}.serialization.ts.open",
+            f"{array_serialization.__name__}.serialization.ts_impl.ts.open",
             ts_open_patch,
         ), get_tensorstore_spec(arr) as spec, mock.patch(
             f"{array_serialization.__name__}._transfer_to_host", transfer_to_host_patch
@@ -144,7 +144,7 @@ def transfer_to_host_patch(*args, **kwargs):
         arr_host = jax.device_get(arr)
         d2h_future = array_serialization.futures.Future()
         with mock.patch(
-            f"{array_serialization.__name__}.serialization.ts.open",
+            f"{array_serialization.__name__}.serialization.ts_impl.ts.open",
             ts_open_patch,
         ), get_tensorstore_spec(arr) as spec, mock.patch(
             f"{array_serialization.__name__}._transfer_to_host", transfer_to_host_patch
@@ -178,7 +178,7 @@ async def ts_open_patch(*_, **__):
 
         d2h_future = array_serialization.futures.Future()
         with mock.patch(
-            f"{array_serialization.__name__}.serialization.ts.open",
+            f"{array_serialization.__name__}.serialization.ts_impl.ts.open",
             ts_open_patch,
         ), get_tensorstore_spec(arr) as spec:
             f = _CommitFuture(
@@ -276,8 +276,10 @@ async def _copy_to_host_patch(shard_infos: list[_ShardInfo]):
             mock.patch(
                 f"{array_serialization.__name__}.serialization._get_metadata", lambda *_: {}
             ),
-            mock.patch(f"{array_serialization.__name__}.serialization.ts.open", open_patch),
-            mock.patch(f"{array_serialization.__name__}.serialization.ts.Spec", mock.MagicMock()),
+            mock.patch(f"{array_serialization.__name__}.serialization.ts_impl.ts.open", open_patch),
+            mock.patch(
+                f"{array_serialization.__name__}.serialization.ts_impl.ts.Spec", mock.MagicMock()
+            ),
         ):
             manager.serialize(arrays, tensorstore_specs, on_commit_callback=lambda: None)
             manager.wait_until_finished()
diff --git a/axlearn/common/causal_lm.py b/axlearn/common/causal_lm.py
index c426fb02c..0ae843c32 100644
--- a/axlearn/common/causal_lm.py
+++ b/axlearn/common/causal_lm.py
@@ -202,7 +202,7 @@ def forward(
         live_targets = _infer_live_targets(input_batch)
         num_targets = live_targets.sum()
 
-        logging.info("Module outputs: %s", jax.tree_structure(module_outputs))
+        logging.info("Module outputs: %s", jax.tree_util.tree_structure(module_outputs))
         accumulation = []
         for k, v in flatten_items(module_outputs):
             if re.fullmatch(regex, k):
diff --git a/axlearn/common/gradient_accumulation.py b/axlearn/common/gradient_accumulation.py
index f70b38b08..6ea48d207 100644
--- a/axlearn/common/gradient_accumulation.py
+++ b/axlearn/common/gradient_accumulation.py
@@ -33,7 +33,7 @@ def _compute_minibatch_size(input_batch: Nested[Tensor], *, steps: int) -> int:
     if steps <= 0:
         raise ValueError("Accumulation steps need to be a positive integer.")
 
-    input_batch_sizes = jax.tree_leaves(jax.tree.map(lambda x: x.shape[0], input_batch))
+    input_batch_sizes = jax.tree_util.tree_leaves(jax.tree.map(lambda x: x.shape[0], input_batch))
 
     if len(input_batch_sizes) == 0:
         raise ValueError("Input batch is empty.")
diff --git a/axlearn/common/update_transformation.py b/axlearn/common/update_transformation.py
index 4c36bacce..3653a39b3 100644
--- a/axlearn/common/update_transformation.py
+++ b/axlearn/common/update_transformation.py
@@ -186,7 +186,7 @@ def real_transform(_):
             return new_updates.delta_updates, new_state
 
         def stop_transform(_):
-            return jax.tree_map(jnp.zeros_like, updates.delta_updates), prev_state
+            return jax.tree_util.tree_map(jnp.zeros_like, updates.delta_updates), prev_state
 
         # We do the computation regardless of the should_update value, so we could have
         # equally used jnp.where() here instead.
diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 24ac70ad4..d67dba571 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -1907,12 +1907,12 @@ def pytree_children(node: Any) -> Sequence[tuple[KeyEntry, Any]]:
         ```
     """
     # pylint: disable-next=protected-access
-    registry_with_keypaths = jax._src.tree_util._registry_with_keypaths
+    # registry_with_keypaths = jax._src.tree_util._registry_with_keypaths
 
-    key_handler = registry_with_keypaths.get(type(node))
-    if key_handler:
-        key_children, _ = key_handler.flatten_with_keys(node)
-        return key_children
+    # key_handler = registry_with_keypaths.get(type(node))
+    # if key_handler:
+    #     key_children, _ = key_handler.flatten_with_keys(node)
+    #     return key_children
 
     flat = jax.tree_util.default_registry.flatten_one_level(node)
     if flat is None:
@@ -2028,7 +2028,7 @@ def validate_contains_paths(x: Nested[Tensor], paths: Sequence[str]):
         except KeyError as e:
             raise ValueError(
                 f"Input is expected to contain '{path}'; "
-                f"instead, it contains: '{jax.tree_structure(x)}'."
+                f"instead, it contains: '{jax.tree_util.tree_structure(x)}'."
             ) from e
 
 
diff --git a/axlearn/experiments/text/gpt/common_test.py b/axlearn/experiments/text/gpt/common_test.py
index b0074f3e7..84a7972eb 100644
--- a/axlearn/experiments/text/gpt/common_test.py
+++ b/axlearn/experiments/text/gpt/common_test.py
@@ -52,7 +52,9 @@ def test_mesh_axes(self):
         # axis for multiple dims.
         for v in cfg.input.input_partitioner.path_rank_to_partition.values():
             visited = set()
-            for axis in jax.tree_leaves(tuple(v)):  # Cast to tuple since PartitionSpec is a leaf.
+            for axis in jax.tree_util.tree_leaves(
+                tuple(v)
+            ):  # Cast to tuple since PartitionSpec is a leaf.
                 self.assertNotIn(axis, visited)
                 visited.add(axis)
             self.assertGreater(len(visited), 0)
diff --git a/pyproject.toml b/pyproject.toml
index 57e415e7e..fa0686dff 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,13 +7,13 @@ name = "axlearn"
 version = "0.1.7"
 description = "AXLearn"
 readme = "README.md"
-requires-python = ">=3.10"
+requires-python = ">=3.12"
 
 # Production dependencies.
 # Minimal requirments for axlearn/common/config.py.
 dependencies = [
     "attrs>=23.1.0", # We use `type` in `attrs.field`
-    "numpy==1.26.4", # verified with tensorflow 2.14 RaggedTensor
+    "numpy>=1.26.4", # verified with tensorflow 2.14 RaggedTensor
 ]
 
 [project.optional-dependencies]
@@ -23,20 +23,20 @@ core = [
     "absl-py==2.1.0",
     "chex==0.1.88",
     "importlab==0.8.1", # breaks pytype on 0.8
-    "jax==0.5.3",
-    "jaxlib==0.5.3",
-    "ml-dtypes==0.4.1",
+    "jax==0.6.2",
+    "jaxlib==0.6.2",
+    "ml-dtypes==0.5.1",
     "msgpack==1.1.0", # for checkpointing.
     "nltk==3.7", # for text preprocessing
     "optax==0.1.7", # optimizers (0.1.0 has known bugs).
     "portpicker",
     "protobuf>=3.20.3",
-    "tensorboard-plugin-profile==2.15.1",
+    "tensorboard-plugin-profile==2.19.0",
     # This has both x86 and arm64 wheels. Underneath the hood it uses tensorflow-macos since 2.13.
-    "tensorflow==2.17.1",
+    "tensorflow==2.19.1",
     "tensorflow-datasets>=4.9.2",
-    "tensorflow-io>=0.37.1", # for tensorflow-2.16. Note that 0.37.0 results in "pure virtual method called".
-    "tensorflow_text==2.17.0; platform_machine == 'x86_64'", # implied by seqio, but also used directly for text processing
+    "tensorflow-io==0.37.2", # for tensorflow-2.16. Note that 0.37.0 results in "pure virtual method called".
+    "tensorflow_text==2.19.0; platform_machine == 'x86_64'", # implied by seqio, but also used directly for text processing
     "tensorstore>=0.1.63", # used for supporting GDA checkpoints
     "toml", # for config management
     "typing-extensions==4.12.2",
@@ -69,15 +69,15 @@ dev = [
     "pylint==2.17.7",
     "pytest", # test runner
     "pytest-xdist", # pytest plugin for test parallelism
-    "pytype==2022.4.22", # type checking
+    "pytype==2024.10.11", # type checking
     "scikit-learn==1.5.2", # test-only
     # Fix AttributeError: module 'scipy.linalg' has no attribute 'tril' and related scipy import errors.
     "scipy==1.12.0",
     "sentencepiece != 0.1.92",
     "tqdm", # test-only
-    "timm==0.6.12", # DiT Dependency test-only
-    "torch>=1.12.1", # test-only
-    "torchvision==0.16.1", # test-only
+    "timm==1.0.17", # DiT Dependency test-only
+    "torch==2.7.1", # test-only
+    "torchvision==0.22.1", # test-only
     "transformers==4.51.3", # test-only
     "wandb", # test-only
     "wrapt", # implied by tensorflow-datasets, but also used in config tests.
@@ -108,8 +108,9 @@ gcp = [
 # Note: Specify -f https://storage.googleapis.com/jax-releases/libtpu_releases.html during install.
 tpu = [
     "axlearn[gcp]",
-    "jax[tpu]==0.5.3", # must be >=0.4.19 for compat with v5p.
-    "pathwaysutils==0.1.1", # For JAX+Pathways single-controller accelerator coordinator.
+    "jax[tpu]==0.6.2", # must be >=0.4.19 for compat with v5p.
+    "pathwaysutils==0.1.1", # For JAX+Pathways single-controller accelerator coordinator.
+    "libtpu==0.0.17",
 ]
 # Vertex AI tensorboard. TODO(markblee): Merge with `gcp`.
 vertexai_tensorboard = [
@@ -120,8 +121,8 @@ vertexai_tensorboard = [
     "setuptools==65.7.0",
     # Pin version to fix Tensorboard uploader TypeError: can only concatenate str (not "NoneType") to str
     # https://github.com/googleapis/python-aiplatform/commit/4f982ab254b05fe44a9d2ed959fca2793961b56c
-    "google-cloud-aiplatform[tensorboard]==1.61.0",
-    "tensorboard",
+    "google-cloud-aiplatform[tensorboard]",
+    "tensorboard==2.19.0",
 ]
 # Dataflow dependencies.
 dataflow = [
@@ -132,9 +133,9 @@
 ]
 # GPU custom kernel dependency.
 gpu = [
-    "triton==2.1.0",
-    "jax[cuda12]==0.5.3",
-    "nvidia-ml-py==12.560.30",
+    "triton>=2.1.0",
+    "jax[cuda12]==0.6.2",
+    "nvidia-ml-py>=12.560.30",
 ]
 # Open API inference.
 open_api = [
@@ -171,11 +172,11 @@ axlearn = "axlearn.cli:main"
 
 [tool.black]
 line-length = 100
-target-version = ['py38', 'py39']
+target-version = ['py310', 'py312']
 
 [tool.ruff]
 line-length = 100
-target-version = 'py39'
+target-version = 'py312'
 
 [tool.pytest.ini_options]
 addopts = "-rs -s -p no:warnings --junitxml=test-results/testing.xml"