Multi-node training support (#440)
michaelbenayoun authored Feb 29, 2024
1 parent d02965b commit 0916761
Showing 18 changed files with 973 additions and 357 deletions.
7 changes: 6 additions & 1 deletion optimum/commands/neuron/cache.py
@@ -23,6 +23,7 @@
     create_custom_cache_repo,
     set_custom_cache_repo_name_in_hf_home,
 )
+from ...neuron.utils.require_utils import requires_torch_neuronx
 from ...neuron.utils.runner import ExampleRunner
 from ...utils import logging
 from ..base import BaseOptimumCLICommand, CommandInfo
@@ -165,9 +166,13 @@ class SynchronizeRepoCommand(BaseOptimumCLICommand):
     @staticmethod
     def parse_args(parser: "ArgumentParser"):
         parser.add_argument("--repo_id", type=str, default=None, help="The name of the repo to use as remote cache.")
+        parser.add_argument(
+            "--cache_dir", type=str, default=None, help="The cache directory that contains the compilation files."
+        )
 
+    @requires_torch_neuronx
     def run(self):
-        synchronize_hub_cache(self.args.repo_id)
+        synchronize_hub_cache(cache_path=self.args.cache_dir, cache_repo_id=self.args.repo_id)
 
 
 class LookupRepoCommand(BaseOptimumCLICommand):
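The new `--cache_dir` flag plugs straight into `synchronize_hub_cache`. A hedged usage sketch follows; the import path and the example values are assumptions, while the keyword arguments come from the hunk above.

# Sketch: synchronizing a local Neuron compilation cache with a Hub cache repo.
# Assumptions: the import path and the example values; the keyword arguments
# (cache_path, cache_repo_id) match the new call in SynchronizeRepoCommand.run.
from optimum.neuron.utils.cache_utils import synchronize_hub_cache

synchronize_hub_cache(
    cache_path="/var/tmp/neuron-compile-cache",   # local compilation artifacts
    cache_repo_id="my-org/optimum-neuron-cache",  # remote Hub repo used as cache
)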
25 changes: 21 additions & 4 deletions optimum/neuron/accelerate/accelerator.py
@@ -43,7 +43,7 @@
     patch_within_function,
     patched_finfo,
 )
-from ..utils.misc import args_and_kwargs_to_kwargs_only
+from ..utils.misc import args_and_kwargs_to_kwargs_only, is_main_worker
 from ..utils.require_utils import requires_neuronx_distributed, requires_torch_xla
 from .optimizer import NeuronAcceleratedOptimizer
 from .scheduler import NeuronAcceleratedScheduler
@@ -173,7 +173,11 @@ def __init__(self, *args, mp_plugin: Optional[ModelParallelismPlugin] = None, ze
         self.gradient_accumulation_steps = num_steps
 
     def _prepare_data_loader_for_distributed(
-        self, data_loader: DataLoader, num_replicas: int, rank: int
+        self,
+        data_loader: DataLoader,
+        num_replicas: int,
+        rank: int,
+        force_drop_last: bool,
     ) -> DataLoader:
         # TODO: make it more robust, similar to the prepare_data_loader function in `accelerate`.
         if isinstance(data_loader.sampler, DistributedSampler):
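The rest of this method (unchanged context not shown here) rebuilds the loader around a DistributedSampler. A rough, self-contained sketch of that logic under stated assumptions: the standalone function name is invented and the sampler defaults are guesses, while the DataLoader keyword arguments and the drop_last expression appear in the next hunk.

# Rough sketch, not the actual implementation: rebuild a DataLoader for
# distributed training, forcing drop_last when pipeline parallelism needs it.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def prepare_for_distributed(data_loader: DataLoader, num_replicas: int, rank: int, force_drop_last: bool) -> DataLoader:
    sampler = DistributedSampler(data_loader.dataset, num_replicas=num_replicas, rank=rank)
    return DataLoader(
        data_loader.dataset,
        batch_size=data_loader.batch_size,
        sampler=sampler,
        num_workers=data_loader.num_workers,
        collate_fn=data_loader.collate_fn,
        pin_memory=data_loader.pin_memory,
        drop_last=data_loader.drop_last or force_drop_last,
    )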
@@ -201,22 +205,32 @@ def _prepare_data_loader_for_distributed(
             num_workers=data_loader.num_workers,
             collate_fn=data_loader.collate_fn,
             pin_memory=data_loader.pin_memory,
-            drop_last=data_loader.drop_last,
+            drop_last=data_loader.drop_last or force_drop_last,
         )
 
         distributed_dataloader._is_accelerate_prepared = True
         return distributed_dataloader
 
     def prepare_data_loader(self, data_loader: DataLoader, device_placement: Optional[bool] = None):
+        force_drop_last = False
         if self.state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
             from neuronx_distributed import parallel_layers
 
             num_replicas = parallel_layers.parallel_state.get_data_parallel_size()
             rank = parallel_layers.parallel_state.get_data_parallel_rank()
+            force_drop_last = parallel_layers.parallel_state.get_pipeline_model_parallel_size() > 1
+            if is_main_worker() and force_drop_last:
+                logger.warning(
+                    "Pipeline parallelism: forcing the dataloader to drop the last incomplete batch because it can "
+                    "cause failure if the last batch size is not divisible by the number of microbatches for the pipeline."
+                )
         else:
             num_replicas = xm.xrt_world_size()
             rank = xm.get_ordinal()
         if self.state.num_processes > 1:
-            data_loader = self._prepare_data_loader_for_distributed(data_loader, num_replicas=num_replicas, rank=rank)
+            data_loader = self._prepare_data_loader_for_distributed(
+                data_loader, num_replicas=num_replicas, rank=rank, force_drop_last=force_drop_last
+            )
+            # No need to wrap the dataloader if we are using pipeline parallelism.
             if self.state.mp_plugin.pipeline_parallel_size == 1:
                 data_loader = MpDeviceLoader(data_loader, self.device)
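The constraint behind the warning is plain divisibility: each batch is carved into equal pipeline microbatches, so a short final batch can be unsplittable. A minimal illustration with made-up numbers:

# Made-up numbers: the last incomplete batch has 10 samples, which cannot be
# split into 4 equal microbatches, so it must be dropped.
dataset_size, batch_size, num_microbatches = 1002, 32, 4

last_batch_size = dataset_size % batch_size     # 10 leftover samples
print(last_batch_size % num_microbatches != 0)  # True -> hence force_drop_last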
@@ -471,6 +485,9 @@ def prepare_model(
 
         model = self.patch_model_for_neuron(model)
 
+        # We do not want to use the cache here as it would imply more communication than we need.
+        model.config.use_cache = False
+
         if self.distributed_type is NeuronDistributedType.XLA_FSDP:
             return self.prepare_model_for_xla_fsdp(
                 model, device_placement=device_placement, evaluation_mode=evaluation_mode
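Turning off use_cache only flips the config flag that transformers decoder models consult in their forward pass; the KV cache helps generation, not training. A minimal sketch of the same idea outside the accelerator, with an illustrative model name:

# Minimal sketch: disable the generation KV cache before training, mirroring
# what NeuronAccelerator.prepare_model now does.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
model.config.use_cache = False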
21 changes: 8 additions & 13 deletions optimum/neuron/accelerate/state.py
@@ -267,23 +267,11 @@ def __init__(
                 os.environ.get("ACCELERATE_USE_NEURONX_DISTRIBUTED_TP", "false") == "true"
                 or os.environ.get("ACCELERATE_USE_NEURONX_DISTRIBUTED_PP", "false") == "true"
             ):
-                if not is_neuronx_distributed_available():
-                    raise RuntimeError(
-                        "Model parallelism requires the neuronx_distributed package. You can install it by "
-                        "running: python -m pip install neuronx_distributed --extra-index-url "
-                        "https://pip.repos.neuron.amazonaws.com"
-                    )
                 if mp_plugin is None:
                     raise ValueError(
-                        "Could not initialize `neuronx_distributed` model parallelism because no "
-                        "`ModelParallelismPlugin` was provided."
+                        "Could not initialize model parallelism because no `ModelParallelismPlugin` was provided."
                     )
                 if mp_plugin.should_parallelize:
-                    if not parallel_state.model_parallel_is_initialized():
-                        parallel_state.initialize_model_parallel(
-                            tensor_model_parallel_size=mp_plugin.tensor_parallel_size,
-                            pipeline_model_parallel_size=mp_plugin.pipeline_parallel_size,
-                        )
                     self.distributed_type = NeuronDistributedType.MODEL_PARALLELISM
                 else:
                     logger.warning(
@@ -293,6 +281,13 @@ def __init__(
                 self.mp_plugin = mp_plugin
             else:
                 self.mp_plugin = ModelParallelismPlugin()
+
+            if torch.distributed.is_initialized() and not parallel_state.model_parallel_is_initialized():
+                parallel_state.initialize_model_parallel(
+                    tensor_model_parallel_size=self.mp_plugin.tensor_parallel_size,
+                    pipeline_model_parallel_size=self.mp_plugin.pipeline_parallel_size,
+                )
+
             if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true":
                 self.distributed_type = NeuronDistributedType.XLA_FSDP
                 if self._mixed_precision != "no":
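The net effect of this refactor: model-parallel groups are now initialized in a single place, as soon as torch.distributed is up, even with the default plugin. A sketch of the resulting order, with the entry point assumed:

# Sketch (assumed entry point): initialize model-parallel groups once the
# process group exists, using the plugin's sizes (both default to 1).
import torch.distributed as dist
from neuronx_distributed.parallel_layers import parallel_state

if dist.is_initialized() and not parallel_state.model_parallel_is_initialized():
    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size=1,    # mp_plugin.tensor_parallel_size
        pipeline_model_parallel_size=1,  # mp_plugin.pipeline_parallel_size
    )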
2 changes: 2 additions & 0 deletions optimum/neuron/accelerate/utils/dataclasses.py
@@ -147,6 +147,7 @@ class ModelParallelismPlugin:
     pipeline_parallel_size: int = 1
     pipeline_parallel_num_microbatches: int = 1
     pipeline_parallel_use_zero1_optimizer: bool = False
+    gradient_checkpointing: bool = False
     checkpoint_dir: Optional[Union[str, Path]] = None
 
     def __post_init__(self):
@@ -176,6 +177,7 @@ def parallelize_model(
             sequence_parallel_enabled=self.sequence_parallel_enabled,
             pipeline_parallel_num_microbatches=self.pipeline_parallel_num_microbatches,
             pipeline_parallel_use_zero1_optimizer=self.pipeline_parallel_use_zero1_optimizer,
+            pipeline_parallel_gradient_checkpointing_enabled=self.gradient_checkpointing,
             checkpoint_dir=self.checkpoint_dir,
         )
         return parallelized_model
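A hedged construction example for the new field; the parallel sizes are illustrative, and only the field names come from this file:

# Illustrative values throughout; gradient_checkpointing is the new field and
# is forwarded as pipeline_parallel_gradient_checkpointing_enabled above.
from optimum.neuron.accelerate.utils.dataclasses import ModelParallelismPlugin

mp_plugin = ModelParallelismPlugin(
    tensor_parallel_size=2,
    pipeline_parallel_size=4,
    pipeline_parallel_num_microbatches=8,
    gradient_checkpointing=True,
)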
(The remaining 14 changed files are not shown.)
