From 5b54984d20a012d4a8f7e38f77ba0c3847575162 Mon Sep 17 00:00:00 2001
From: Steboss
Date: Thu, 22 May 2025 10:31:51 +0100
Subject: [PATCH] gather all the changes to jax.tree

---
 axlearn/common/causal_lm.py                 | 2 +-
 axlearn/common/gradient_accumulation.py     | 2 +-
 axlearn/common/update_transformation.py     | 2 +-
 axlearn/common/utils.py                     | 2 +-
 axlearn/experiments/text/gpt/common_test.py | 4 +++-
 5 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/axlearn/common/causal_lm.py b/axlearn/common/causal_lm.py
index ff5ad78cb..db5bfd5fb 100644
--- a/axlearn/common/causal_lm.py
+++ b/axlearn/common/causal_lm.py
@@ -199,7 +199,7 @@ def forward(
         live_targets = target_labels >= 0
         num_targets = live_targets.sum()
 
-        logging.info("Module outputs: %s", jax.tree_structure(module_outputs))
+        logging.info("Module outputs: %s", jax.tree_util.tree_structure(module_outputs))
         accumulation = []
         for k, v in flatten_items(module_outputs):
             if re.fullmatch(regex, k):
diff --git a/axlearn/common/gradient_accumulation.py b/axlearn/common/gradient_accumulation.py
index f70b38b08..6ea48d207 100644
--- a/axlearn/common/gradient_accumulation.py
+++ b/axlearn/common/gradient_accumulation.py
@@ -33,7 +33,7 @@ def _compute_minibatch_size(input_batch: Nested[Tensor], *, steps: int) -> int:
     if steps <= 0:
         raise ValueError("Accumulation steps need to be a positive integer.")
 
-    input_batch_sizes = jax.tree_leaves(jax.tree.map(lambda x: x.shape[0], input_batch))
+    input_batch_sizes = jax.tree_util.tree_leaves(jax.tree.map(lambda x: x.shape[0], input_batch))
 
     if len(input_batch_sizes) == 0:
         raise ValueError("Input batch is empty.")
diff --git a/axlearn/common/update_transformation.py b/axlearn/common/update_transformation.py
index 4c36bacce..3653a39b3 100644
--- a/axlearn/common/update_transformation.py
+++ b/axlearn/common/update_transformation.py
@@ -186,7 +186,7 @@ def real_transform(_):
             return new_updates.delta_updates, new_state
 
         def stop_transform(_):
-            return jax.tree_map(jnp.zeros_like, updates.delta_updates), prev_state
+            return jax.tree_util.tree_map(jnp.zeros_like, updates.delta_updates), prev_state
 
         # We do the computation regardless of the should_update value, so we could have
         # equally used jnp.where() here instead.
diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 5322b4405..8b2507ec9 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -1991,7 +1991,7 @@ def validate_contains_paths(x: Nested[Tensor], paths: Sequence[str]):
         except KeyError as e:
             raise ValueError(
                 f"Input is expected to contain '{path}'; "
-                f"instead, it contains: '{jax.tree_structure(x)}'."
+                f"instead, it contains: '{jax.tree_util.tree_structure(x)}'."
             ) from e
 
 
diff --git a/axlearn/experiments/text/gpt/common_test.py b/axlearn/experiments/text/gpt/common_test.py
index b0074f3e7..84a7972eb 100644
--- a/axlearn/experiments/text/gpt/common_test.py
+++ b/axlearn/experiments/text/gpt/common_test.py
@@ -52,7 +52,9 @@ def test_mesh_axes(self):
         # axis for multiple dims.
         for v in cfg.input.input_partitioner.path_rank_to_partition.values():
             visited = set()
-            for axis in jax.tree_leaves(tuple(v)):  # Cast to tuple since PartitionSpec is a leaf.
+            for axis in jax.tree_util.tree_leaves(
+                tuple(v)
+            ):  # Cast to tuple since PartitionSpec is a leaf.
                 self.assertNotIn(axis, visited)
                 visited.add(axis)
             self.assertGreater(len(visited), 0)
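
Note on the migration: every hunk above swaps a deprecated top-level alias (jax.tree_map, jax.tree_leaves, jax.tree_structure) for its jax.tree_util equivalent, which accepts the same arguments. A minimal standalone sketch of the migrated calls, assuming a JAX version that provides jax.tree.map alongside jax.tree_util (the toy `batch` pytree below is illustrative, not from the patch):

    import jax
    import jax.numpy as jnp

    # A toy pytree standing in for the nested input batches handled above.
    batch = {"inputs": jnp.ones((8, 16)), "labels": jnp.zeros((8,), dtype=jnp.int32)}

    # Treedef of the pytree, as logged in causal_lm.py and reported in utils.py.
    print(jax.tree_util.tree_structure(batch))

    # Leading batch dim of every leaf, as in _compute_minibatch_size.
    sizes = jax.tree_util.tree_leaves(jax.tree.map(lambda x: x.shape[0], batch))
    assert all(size == 8 for size in sizes)

    # Zeroed-out pytree with matching shapes/dtypes, as in stop_transform.
    zeros = jax.tree_util.tree_map(jnp.zeros_like, batch)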