
import numpy as np

+from collections.abc import Iterable
from jax.experimental import mesh_utils
from jax.experimental.serialize_executable import deserialize_and_load
from jax.sharding import PartitionSpec as P
+
import jax
import jax.numpy as jnp
+import jax.tree_util as jtu

import optax

@@ -343,25 +346,51 @@ def calculate_prefill_tflops_per_device(num_model_parameters, prefill_length, co
  return total_tflops, learnable_weight_tflops, causal_attention_tflops


-def assert_params_sufficiently_sharded(params, mesh, tolerance):
-  """Checks whether most params are sharded across sharding axis.
+def get_mesh_axes_used_by_tensor_spec(tensor_sharding_spec):
+  """
+  Extracts the set of mesh axis names that a tensor's PartitionSpec uses.
+
+  This function inspects a tensor's sharding specification (PartitionSpec) and
+  identifies which mesh axes are actively used for sharding. If a tensor is not
+  sharded (i.e., fully replicated), the resulting set is empty.
+
+  Args:
+    tensor_sharding_spec: The PartitionSpec of a tensor, which defines how it is
+      partitioned across the mesh. It can be None or contain strings and
+      iterables naming mesh axes.
+
+  Returns:
+    A set of strings, where each string is a mesh axis name used by the
+    tensor's sharding spec. Returns an empty set for unsharded tensors.
+  """
+  if tensor_sharding_spec is None:
+    return set()
+  # Flatten the sharding spec, as it can contain nested iterables (e.g., ('data', 'mdl')).
+  flattened_axes = sum(
+      [
+          [axis] if isinstance(axis, str) else list(axis) if isinstance(axis, Iterable) else []
+          for axis in tensor_sharding_spec
+      ],
+      [],
+  )
+  return set(flattened_axes)
+
+
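A quick illustration of the flattening behavior (editorial example, not part of the change; `P` is the PartitionSpec alias imported above):

    get_mesh_axes_used_by_tensor_spec(P(("fsdp", "tensor"), None, "sequence"))  # {"fsdp", "tensor", "sequence"}
    get_mesh_axes_used_by_tensor_spec(P(None, None))                            # set() -- fully replicated
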
+def _get_nontrivial_mesh_axes(mesh):
+  """
+  Returns mesh axes from config that are valid and have more than one shard.

-  This function determines whether the majority of parameters are distributed
-  across a specified sharding axes with an acceptable tolerance. It compares the
-  current distribution to a scenario where all parameters are fully sharded
-  across the 'fsdp', 'fsdp_transpose', 'sequence', and 'tensor' axes.
+  This function identifies which of the predefined potential sharding axes are
+  actually present in the current device mesh and are configured with a size
+  greater than one (i.e., are actually sharded).

  Args:
-    params: params of the model state
-    mesh: mesh constructed from config
-    tolerance: float between 0.0 and 1.0 representing the allowed percentage of
-      non-sharded parameters.
+    mesh: The device mesh object, which contains information about the mesh
+      topology, including axis names and their sizes.
+
  Returns:
-    bool: True if the majority of parameters are sufficiently sharded
+    A set of strings, where each string is a mesh axis name that is both
+    pre-configured as a target for sharding and has more than one shard in the mesh.
  """
-  total_num_params = max_utils.calculate_num_params_from_pytree(params)
-  product_num_devices_for_weight_sharding = 1
-  for axis in [
+
+  target_sharding_axes_config = [
      "fsdp",
      "fsdp_transpose",
      "sequence",
@@ -372,19 +401,156 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance):
372401 "tensor_sequence" ,
373402 "stage" ,
374403 "expert" ,
375- ]:
376- product_num_devices_for_weight_sharding *= mesh .shape [axis ]
377- total_num_params_per_chip = max_utils .calculate_total_params_per_chip (params )
378- perfectly_sharded_params_per_chip = total_num_params / product_num_devices_for_weight_sharding
379- assert total_num_params_per_chip >= perfectly_sharded_params_per_chip , (
380- "Number of parameters per chip must not be less than in the ideal sharded "
381- "scenario across `fsdp`, `fsdp_transpose`, `context`, `sequence`, `tensor`, `tensor_transpose`, "
382- "`tensor_sequence`, `stage`, `expert` axes."
383- )
384- unsharded_param_perc = total_num_params_per_chip / perfectly_sharded_params_per_chip - 1
385- assert unsharded_param_perc < tolerance , (
386- f"Number of unsharded parameters exceeds tolerance { tolerance * 100 } % "
387- f"of total parameters with a value of { unsharded_param_perc * 100 } %."
404+ ]
405+
406+ # Filter the target axes to find those that exist in the current mesh
407+ # and have a size greater than 1, meaning they are actually used for sharding.
408+ return {axis for axis in target_sharding_axes_config if axis in mesh .axis_names and mesh .shape [axis ] > 1 }
409+
410+
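For instance (an illustrative sketch, not part of the change; assumes a host with 8 devices and that "tensor" remains one of the target axes, as the replaced assertion message suggests):

    devices = mesh_utils.create_device_mesh((1, 4, 2))
    mesh = jax.sharding.Mesh(devices, ("data", "fsdp", "tensor"))
    _get_nontrivial_mesh_axes(mesh)  # {"fsdp", "tensor"} -- "data" is not a target axis, the others have >1 shard
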
+def _analyze_sharding(params, mesh, valid_target_mesh_axes):
+  """
+  Analyzes parameters to find those that are not sharded across every valid mesh axis.
+
+  This function iterates through all parameters in a model, checking their
+  sharding specifications. It identifies parameters that are not sharded along
+  every one of the provided valid target axes and are therefore at least
+  partially replicated across those axes.
+
+  Args:
+    params: A PyTree of model parameters.
+    mesh: The device mesh object.
+    valid_target_mesh_axes: A set of mesh axis names that are considered valid targets for sharding.
+
+  Returns:
+    A tuple containing:
+      - unsharded_params_total_size (int): The total size (number of elements) of all parameters found to be
+        unsharded on the target axes.
+      - problematic_tensors_details (list): A list of dictionaries, where each
+        dictionary contains details about a tensor that is not sharded on every target axis.
+  """
+  unsharded_params_total_size = 0  # Total size of parameters not sharded on every target axis.
+  problematic_tensors_details = []  # Details of each offending tensor.
+
+  # Get a flattened list of all parameters (leaves) in the PyTree, along with their paths.
+  all_params_leaves = jtu.tree_leaves_with_path(params)
+
+  for path, p_leaf in all_params_leaves:
+    param_name_str = jtu.keystr(path)  # Convert the tree path to a readable string.
+
+    # Check that sharding and spec exist and are valid.
+    sharding = getattr(p_leaf, "sharding", None)
+    spec = getattr(sharding, "spec", None)
+    assert sharding is not None and spec is not None and isinstance(spec, P), (
+        f"Parameter '{param_name_str}' is missing a valid '.sharding.spec'. "
+        "Expected 'p_leaf.sharding.spec' to be a non-null 'PartitionSpec'."
+    )
+
+    current_sharding_spec = p_leaf.sharding.spec
+    # Identify the mesh axes this tensor is actually sharded over.
+    mesh_axes_used = get_mesh_axes_used_by_tensor_spec(current_sharding_spec)
+    # Check whether the parameter is sharded on all of the valid target axes.
+    is_sharded_on_all_target_axes = all(axis in mesh_axes_used for axis in valid_target_mesh_axes)
+
+    # If the parameter is not sharded on all of the target axes, it is considered "problematic".
+    if not is_sharded_on_all_target_axes:
+      unsharded_params_total_size += p_leaf.size  # Add to the total unsharded parameter size.
+      unsharded_axes = set(valid_target_mesh_axes) - set(mesh_axes_used)
+      # Record detailed info for the error report.
+      problematic_tensors_details.append(
+          {
+              "name": param_name_str,  # Tensor name.
+              "size": p_leaf.size,  # Tensor size.
+              "shape": p_leaf.shape,  # Tensor shape.
+              "spec": str(current_sharding_spec),  # Tensor sharding spec as a string.
+              "available_axes": sorted(valid_target_mesh_axes),  # Axes that could be used for sharding.
+              "unsharded_axes": sorted(unsharded_axes),  # Target axes this tensor is not sharded on.
+          }
+      )
+  # Return the total size of unsharded parameters and the list of problematic tensors.
+  return unsharded_params_total_size, problematic_tensors_details
+
+
+def _raise_if_unsharded_exceeds_tolerance(unsharded_size, total_size, tolerance, problematic_tensors_details):
+  """
+  Raises an AssertionError if the fraction of unsharded parameters exceeds the given tolerance.
+
+  This function calculates the proportion of model parameters that are unsharded
+  and compares it against a specified tolerance. If the tolerance is exceeded,
+  it constructs and raises a detailed error message.
+
+  Args:
+    unsharded_size: The total size of parameters not sharded on the target axes.
+    total_size: The total size of all parameters in the model.
+    tolerance: A float (e.g., 0.05 for 5%) representing the maximum allowed fraction of unsharded parameters.
+    problematic_tensors_details: A list of details about the unsharded tensors,
+      used to generate an informative error message.
+
+  Raises:
+    AssertionError: If the fraction of unsharded parameters is greater than the tolerance.
+  """
+  if total_size <= 0:
+    raise ValueError("Total size must be greater than zero.")
+
+  # Calculate the fraction of unsharded parameters.
+  unsharded_param_perc = unsharded_size / total_size
+
+  # If the fraction is over the tolerance, prepare and raise an error.
+  if unsharded_param_perc > tolerance:
+    # Sort the problematic tensors by size to show the largest ones first.
+    problematic_tensors_details.sort(key=lambda x: x["size"], reverse=True)
+
+    # Begin constructing the error message.
+    error_msg_lines = [
+        f"Unsharded parameter percentage ({unsharded_param_perc:.2%}) exceeds tolerance ({tolerance:.2%})."
+    ]
+    # Add a header explaining the issue.
+    error_msg_lines.append(
+        "The following large tensors are not sharded on every available target axis "
+        "and could be sharded further on at least one of them:"
+    )
+    # Add details for the five largest problematic tensors.
+    for detail in problematic_tensors_details[:5]:
+      error_msg_lines.append(
+          f"  - Name: {detail['name']} (Size: {detail['size']}, Shape: {detail['shape']}, Spec: {detail['spec']}) "
+          f"is unsharded on axes: {detail['unsharded_axes']} "
+          f"and could be sharded on: {detail['available_axes']}"
+      )
+
+    # Raise the assertion error with the combined, formatted message.
+    raise AssertionError("\n".join(error_msg_lines))
+
+
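As a worked example with hypothetical numbers: if the flagged tensors total 3e8 elements out of 1e10, the unsharded fraction is 3e8 / 1e10 = 3%, so a tolerance of 0.02 (2%) raises an AssertionError while 0.05 (5%) passes.
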
+def assert_params_sufficiently_sharded(params, mesh, tolerance):
+  """
+  Asserts that the total size of insufficiently sharded parameters is within a given tolerance.
+
+  This is the main function that orchestrates the sharding analysis. It determines
+  the total number of parameters, identifies valid sharding axes, analyzes the
+  sharding of all parameters, and then raises an error if the amount of
+  unsharded parameters exceeds the specified tolerance.
+
+  Args:
+    params: A PyTree of model parameters.
+    mesh: The device mesh object.
+    tolerance: A float representing the maximum allowed fraction of unsharded parameters.
+  """
+  # Calculate the total number of parameters in the model.
+  total_num_params = max_utils.calculate_num_params_from_pytree(params)
+
+  # Get the set of nontrivial mesh axes that can be used for sharding.
+  valid_target_mesh_axes = _get_nontrivial_mesh_axes(mesh)
+  # If there are no valid axes to shard along, there is nothing to check.
+  if not valid_target_mesh_axes:
+    return  # Exit early.
+
+  # Analyze the parameters to find the total size of unsharded parameters
+  # and get details on which tensors are problematic.
+  unsharded_params_total_size, problematic_tensors_details = _analyze_sharding(params, mesh, valid_target_mesh_axes)
+
+  # Check if the amount of unsharded parameters is within the tolerance and
+  # raise an exception if it is not.
+  _raise_if_unsharded_exceeds_tolerance(
+      unsharded_params_total_size, total_num_params, tolerance, problematic_tensors_details
  )


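A typical call site, as an illustrative sketch (names such as `state.params` and the tolerance value are placeholders, not part of this change):

    # After the sharded train state has been created:
    assert_params_sufficiently_sharded(state.params, mesh, tolerance=0.02)
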
@@ -848,3 +1014,67 @@ def schedule(step):
    boundaries.append(warmup_steps + cos_steps + constant_zero_steps)

  return optax.join_schedules(pieces, boundaries)
+
+
+def get_formatted_sharding_annotations(params, mesh=None):
+  """
+  Generates a readable string report of sharding annotations for all parameters.
+
+  This function iterates through a PyTree of model parameters and inspects the
+  sharding information attached to each parameter (leaf). It creates a
+  human-readable summary that is useful for debugging sharding configurations.
+
+  Args:
+    params: The PyTree of model parameters to inspect.
+    mesh: (Optional) The device mesh. If provided, its axis names and shape
+      are included in the report for additional context.
+
+  Returns:
+    A single string containing the formatted report of sharding annotations
+    for every parameter, with each entry on a new line.
+  """
+  # Initialize the report lines, starting with a title.
+  annotation_lines = ["Comprehensive Weight Sharding Annotations:"]
+
+  # If a mesh object is provided, add its details to the report header.
+  if mesh:
+    annotation_lines.append(f"Mesh axes: {mesh.axis_names}, Mesh shape: {mesh.shape}")
+    annotation_lines.append("-" * 30)
+
+  # Get a flattened list of all parameters (leaves) and their corresponding paths in the PyTree.
+  all_params_leaves = jtu.tree_leaves_with_path(params)
+
+  for path, p_leaf in all_params_leaves:
+    # Convert the parameter's path (a sequence of keys) into a readable string name.
+    param_name_str = jtu.keystr(path)
+    # Get the shape of the parameter as a string.
+    shape_str = str(p_leaf.shape)
+    # Default description used when no sharding information is found.
+    sharding_desc = "N/A"
+
+    # Check if the parameter leaf has a 'sharding' attribute.
+    if hasattr(p_leaf, "sharding"):
+      # Case 1: Standard JAX sharding with a PartitionSpec.
+      if hasattr(p_leaf.sharding, "spec") and p_leaf.sharding.spec is not None:
+        # The spec is a PartitionSpec (a tuple); format each entry for readability.
+        spec_parts = []
+        for item in p_leaf.sharding.spec:
+          # Represent None as "Replicated" to make it explicit.
+          spec_parts.append(str(item) if item is not None else "Replicated")
+        sharding_desc = f"PartitionSpec({', '.join(spec_parts)})"
+      # Case 2: The parameter is explicitly marked as fully replicated.
+      elif hasattr(p_leaf.sharding, "spec") and p_leaf.sharding.spec is None:
+        sharding_desc = "Fully Replicated (spec is None)"
+      # Case 3: A sharding object exists but has no recognized spec attribute.
+      else:
+        # Fall back to the string representation of the sharding object itself.
+        sharding_desc = str(p_leaf.sharding)
+    # Case 4: The parameter has no .sharding attribute at all.
+    else:
+      sharding_desc = "No .sharding attribute found"
+
+    # Append the formatted details for the current parameter.
+    annotation_lines.append(f"  - Param: {param_name_str}\n    Shape: {shape_str}\n    Sharding: {sharding_desc}")
+  # Join all the collected lines into a single string, separated by newlines.
+  return "\n".join(annotation_lines)
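Illustrative usage for debugging (the `state.params` name is a placeholder for any parameter PyTree):

    print(get_formatted_sharding_annotations(state.params, mesh))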