Loss too big when using TP #740

Draft: wants to merge 5 commits into main
3 changes: 2 additions & 1 deletion docs/source/training_tutorials/sft_lora_finetune_llm.py
@@ -30,12 +30,13 @@ def training_function(script_args, training_args):
     tokenizer = AutoTokenizer.from_pretrained(script_args.model_id)
     tokenizer.pad_token = tokenizer.eos_token
 
+    # with lazy_load_for_parallelism(tensor_parallel_size=1):
     with lazy_load_for_parallelism(tensor_parallel_size=training_args.tensor_parallel_size):
         model = AutoModelForCausalLM.from_pretrained(script_args.model_id)
 
     config = LoraConfig(
         r=16,
-        lora_alpha=16,
+        lora_alpha=32,
         lora_dropout=0.05,
         target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj"],
         bias="none",
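Note: with r=16, raising lora_alpha from 16 to 32 doubles the LoRA scaling factor (PEFT scales the adapter update by lora_alpha / r), so the adapters now perturb the frozen weights twice as strongly. A minimal sketch of that relationship, independent of the Neuron-side changes below:

# LoRA applies W_eff = W + (lora_alpha / r) * (B @ A); the alpha/r ratio sets
# how strongly the adapter update contributes to the frozen weights.
def lora_scaling(lora_alpha: int, r: int) -> float:
    return lora_alpha / r

print(lora_scaling(16, 16))  # 1.0 -> original tutorial value
print(lora_scaling(32, 16))  # 2.0 -> value in this draft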
69 changes: 55 additions & 14 deletions optimum/neuron/accelerate/accelerator.py
@@ -23,7 +23,7 @@
 import warnings
 from functools import partial
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union, Dict
 
 import torch
 from accelerate import Accelerator
@@ -81,6 +81,7 @@
     xm = None
 
 if is_neuronx_distributed_available():
+    import neuronx_distributed as nxd
     from neuronx_distributed.utils.model_utils import move_model_to_device
 
 
@@ -98,6 +99,7 @@
 class NeuronAccelerator(Accelerator):
     def __init__(
         self,
+        nxd_config: Dict[str, Any],
         *args,
         mp_plugin: Optional[ModelParallelismPlugin] = None,
         zero_1: bool = False,
@@ -146,13 +148,17 @@ def patched_is_torch_xla_available(check_is_tpu: bool = False, check_is_gpu: bool
 
         accelerate.state.is_torch_xla_available = patched_is_torch_xla_available
 
-        patched_accelerator_state = partial(
-            NeuronAcceleratorState, mp_plugin=mp_plugin, autocast_backend=autocast_backend
-        )
-        with Patcher([("accelerate.accelerator.AcceleratorState", patched_accelerator_state)]):
+        self.mp_plugin = mp_plugin
+        self.nxd_config = nxd_config
+
+        # patched_accelerator_state = partial(
+        #     NeuronAcceleratorState, mp_plugin=mp_plugin, autocast_backend=autocast_backend
+        # )
+        # with Patcher([("accelerate.accelerator.AcceleratorState", patched_accelerator_state)]):
+        with Patcher([("accelerate.accelerator.AcceleratorState", NeuronAcceleratorState)]):
             super().__init__(**full_kwargs)
 
-        self.zero_1 = zero_1
+        self.zero_1 = self.nxd_config["optimizer_config"]["zero_one_enabled"]
 
         if self.autocast_handler is None:
             enabled = self.state.mixed_precision == "bf16" and autocast_backend is AutocastBackend.AMP
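Note: the only structure this hunk actually reads out of nxd_config is the ZeRO-1 flag. A hedged sketch of the expected shape of the dict; any key beyond optimizer_config["zero_one_enabled"] is an assumption, and in practice the dict would presumably come from nxd.neuronx_distributed_config(...), as the commented-out block in state.py further down suggests:

# Hypothetical nxd_config passed as the new first positional argument of
# NeuronAccelerator.__init__. Only optimizer_config["zero_one_enabled"] is read
# in this hunk; it replaces the old zero_1 constructor flag.
nxd_config = {
    "optimizer_config": {
        "zero_one_enabled": False,
    },
}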
@@ -300,17 +306,32 @@ def _prepare_optimizer_for_zero_1(self, optimizer: torch.optim.Optimizer, device
         )
         return zero_1_optimizer
 
+    @requires_neuronx_distributed
     @patch_within_function(("accelerate.accelerator.AcceleratedOptimizer", NeuronAcceleratedOptimizer))
     def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement: Optional[bool] = None):
-        if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
-            optimizer = self._prepare_optimizer_for_mp(optimizer, device_placement=device_placement)
-        if self.zero_1:
-            optimizer = self._prepare_optimizer_for_zero_1(optimizer, device_placement=device_placement)
+        import neuronx_distributed as nxd
+
+        #cpu_parameters_to_xla = collections.ChainMap(*self._model_cpu_parameters_to_xla.values())
+        #xla_parameters, _ = Parallelizer.optimizer_cpu_params_to_xla_params(optimizer, cpu_parameters_to_xla)
+        #print(xla_parameters)
+
+        optimizer = nxd.initialize_parallel_optimizer(
+            self.nxd_config,
+            optimizer.__class__,
+            # xla_parameters,
+            optimizer.param_groups,
+            **optimizer.defaults,
+        )
+        optimizer.zero_grad()
+        # if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
+        #     optimizer = self._prepare_optimizer_for_mp(optimizer, device_placement=device_placement)
+        # if self.zero_1:
+        #     optimizer = self._prepare_optimizer_for_zero_1(optimizer, device_placement=device_placement)
         # Edge case: if the optimizer was created lazily outside of the Model Parallelism and/or ZeRO-1 setting, we make
         # sure to actually load the proper parameters.
-        if hasattr(optimizer, "_args_to_recreate"):
-            args, kwargs = optimizer._args_to_recreate
-            optimizer = optimizer.__class__(*args, **kwargs)
+        # if hasattr(optimizer, "_args_to_recreate"):
+        #     args, kwargs = optimizer._args_to_recreate
+        #     optimizer = optimizer.__class__(*args, **kwargs)
 
         return super().prepare_optimizer(optimizer, device_placement=device_placement)
 
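Note: stripped of the commented-out branches, the optimizer hand-off this draft moves to looks roughly like the sketch below. It mirrors the nxd.initialize_parallel_optimizer call above; the helper name is illustrative and the exact keyword handling should be checked against the installed neuronx_distributed release:

import torch
import neuronx_distributed as nxd

def to_nxd_optimizer(nxd_config: dict, optimizer: torch.optim.Optimizer):
    # Instead of wrapping the torch optimizer manually for TP / ZeRO-1
    # (_prepare_optimizer_for_mp / _prepare_optimizer_for_zero_1), let
    # neuronx_distributed build its own parallel optimizer from the original
    # optimizer's param groups and defaults.
    parallel_optimizer = nxd.initialize_parallel_optimizer(
        nxd_config,
        optimizer.__class__,     # e.g. torch.optim.AdamW
        optimizer.param_groups,  # reuse the already-built param groups
        **optimizer.defaults,    # lr, betas, weight_decay, ...
    )
    parallel_optimizer.zero_grad()
    return parallel_optimizer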
@@ -449,6 +470,7 @@ def _tie_or_clone_weights_for_mp(self, output_embeddings, input_embeddings):
     def prepare_model(
         self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False
     ):
+        print("Prepare model")
         # If the model was already prepared, we skip.
         if model in self._models:
             return model
@@ -500,7 +522,26 @@ def backward(self, loss, **kwargs):
             self.scaler.scale(loss).backward(**kwargs)
         else:
             loss.backward(**kwargs)
-
+        # vector_norm = [torch.vector_norm(p.grad, 2) for p in self._models[0].parameters() if p.requires_grad]
+        # norm = torch.nn.utils.clip_grad_norm_([p for p in self._models[0].parameters() if p.requires_grad], 1.0)
+        # xm.mark_step()
+        # print(vector_norm)
+        # self._models[0].to("cpu")
+        # print(self._models[0])
+        # print(norm)
+        # for n, p in self._models[0].named_parameters():
+        #     if not p.requires_grad or p.grad is None:
+        #         continue
+        #     p = p.grad
+        #     print(f"Gradient of {n}")
+        #     print(f"Min: {p.min():.3f}")
+        #     print(f"Max: {p.max():.3f}")
+        #     print(f"Mean: {p.mean():.3f}")
+        #     print(f"Std: {p.std():.3f}")
+        #     print(f"L1 norm: {p.norm(p=1):.3f}")
+        #     print(f"L2 norm: {p.norm(p=2):.3f}")
+        # assert 3==2
+
     @contextlib.contextmanager
     def autocast(self, cache_enabled: bool = False, autocast_handler: Optional[AutocastKwargs] = None):
         if cache_enabled:
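Note: the commented-out block in backward() is ad-hoc gradient inspection. A tidier, hedged version of the same idea as a standalone helper (the function is illustrative and not part of optimum-neuron):

import torch

def log_grad_stats(model: torch.nn.Module) -> None:
    # Print basic statistics for every trainable parameter's gradient. On
    # XLA/Neuron devices, reading these scalars forces the lazy graph to
    # execute, so keep this for debugging runs only.
    for name, param in model.named_parameters():
        if not param.requires_grad or param.grad is None:
            continue
        grad = param.grad
        print(
            f"{name}: min={grad.min():.3f} max={grad.max():.3f} "
            f"mean={grad.mean():.3f} std={grad.std():.3f} "
            f"L1={grad.norm(p=1):.3f} L2={grad.norm(p=2):.3f}"
        )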
17 changes: 11 additions & 6 deletions optimum/neuron/accelerate/optimizer.py
@@ -73,24 +73,27 @@ def __init__(
             self.parameters = [p for group in self.optimizer.param_groups for p in group["params"]]
             self.parameter_ids = {id(p) for p in self.parameters}
 
+        self.total_grad_norm = []
+
     # TODO: might be needed to override this soon.
     def load_state_dict(self, state_dict):
         return super().load_state_dict(state_dict)
 
     def prepare_clip_grad_norm(self, parameters, max_norm, norm_type=2):
         parameter_ids = {id(p) for p in parameters}
-        if parameter_ids == self.parameter_ids or isinstance(self.optimizer, ZeroRedundancyOptimizer):
-            self.clip_grad_norm_to_perform = {"max_norm": max_norm, "norm_type": norm_type}
+        # if parameter_ids == self.parameter_ids or isinstance(self.optimizer, ZeroRedundancyOptimizer):
+        # assert 3==2
+        self.clip_grad_norm_to_perform = {"max_norm": max_norm, "norm_type": norm_type}
+        return self.total_grad_norm
 
     @requires_neuronx_distributed
     def step(self, closure=None):
         from neuronx_distributed import parallel_layers
         from neuronx_distributed.parallel_layers.grads import bucket_allreduce_gradients
 
         if self.gradient_state.sync_gradients:
-            # For sequence-parallel, we have to explicitly all-reduce the layernorm gradients.
-            if self.accelerator_state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
-                allreduce_sequence_parallel_gradients(self.optimizer)
+            self.optimizer.step()
+            return
 
             if isinstance(self.optimizer, ZeroRedundancyOptimizer):
                 if self.clip_grad_norm_to_perform is not None:
@@ -113,7 +116,9 @@ def step(self, closure=None):
                 if parallel_layers.parallel_state.get_data_parallel_size() > 1:
                     bucket_allreduce_gradients(xm._fetch_gradients(self.optimizer))
                 if self.clip_grad_norm_to_perform is not None:
-                    parallel_layers.clip_grad_norm(self.parameters, **self.clip_grad_norm_to_perform)
+                    self.total_grad_norm.clear()
+                    total_grad_norm = parallel_layers.clip_grad_norm(self.parameters, **self.clip_grad_norm_to_perform)
+                    self.total_grad_norm.append(total_grad_norm)
                     self.clip_grad_norm_to_perform = None
                 self.optimizer.step()
             elif self.scaler is not None:
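Note: prepare_clip_grad_norm now returns the total_grad_norm list and step() appends the clipped norm into it, presumably so the trainer can read back the gradient norm for logging. A hedged sketch of that usage (trainer-side names are illustrative); be aware that the early self.optimizer.step() / return added in the first hunk appears to bypass the branch that does the appending, so as written the list may stay empty:

# optimizer is the NeuronAcceleratedOptimizer; trainable_params and
# max_grad_norm are illustrative names from the surrounding training loop.
grad_norm_box = optimizer.prepare_clip_grad_norm(trainable_params, max_grad_norm)
optimizer.step()  # clipping (when reached) appends the norm to the list
if grad_norm_box:
    print(f"grad_norm = {grad_norm_box[-1]}")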
30 changes: 19 additions & 11 deletions optimum/neuron/accelerate/state.py
@@ -45,7 +45,8 @@
 import torch_xla.core.xla_model as xm
 
 if is_neuronx_distributed_available():
-    from neuronx_distributed.parallel_layers import parallel_state
+    import neuronx_distributed as nxd
+    # from neuronx_distributed.parallel_layers import parallel_state
 
 
 logger = logging.get_logger()
@@ -146,7 +147,7 @@ def __init__(
                 os.environ["ACCELERATE_USE_AMP"] = "true"
             NeuronPartialState(cpu, **kwargs)
         self.__dict__.update(NeuronPartialState._shared_state)
-        self._check_initialized(mixed_precision, cpu, autocast_backend)
+        self._check_initialized(mixed_precision, cpu) #, autocast_backend)
         if not self.initialized:
             self.deepspeed_plugin = None
             self.ipex_plugin = None
@@ -200,11 +201,18 @@ def __init__(
 
             self.mp_plugin = mp_plugin
 
-            if torch.distributed.is_initialized() and not parallel_state.model_parallel_is_initialized():
-                parallel_state.initialize_model_parallel(
-                    tensor_model_parallel_size=self.mp_plugin.tensor_parallel_size,
-                    pipeline_model_parallel_size=self.mp_plugin.pipeline_parallel_size,
-                )
+            # nxd_config = nxd.neuronx_distributed_config(
+            #     tensor_parallel_size=self.mp_plugin.tensor_parallel_size,
+            #     pipeline_parallel_size=self.mp_plugin.pipeline_parallel_size,
+            #     expert_parallel_size=1, # TODO: add proper argument here once we support MOE
+
+            # )
+
+            # if torch.distributed.is_initialized() and not parallel_state.model_parallel_is_initialized():
+            #     parallel_state.initialize_model_parallel(
+            #         tensor_model_parallel_size=self.mp_plugin.tensor_parallel_size,
+            #         pipeline_model_parallel_size=self.mp_plugin.pipeline_parallel_size,
+            #     )
 
             if self.distributed_type is DistributedType.NO:
                 if is_ipex_available():
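Note: the commented-out block points at the intended direction: let neuronx_distributed set up the parallel state from an nxd_config instead of calling parallel_state.initialize_model_parallel directly. A hedged sketch of that call once uncommented; the tensor/pipeline/expert parallel sizes come from the block above, while the optimizer_config argument is an assumption inferred from the nxd_config["optimizer_config"]["zero_one_enabled"] lookup in accelerator.py and should be checked against the installed neuronx_distributed version:

import neuronx_distributed as nxd

nxd_config = nxd.neuronx_distributed_config(
    tensor_parallel_size=8,    # e.g. mp_plugin.tensor_parallel_size
    pipeline_parallel_size=1,  # e.g. mp_plugin.pipeline_parallel_size
    expert_parallel_size=1,    # MoE not supported yet, per the TODO above
    optimizer_config={"zero_one_enabled": False},  # assumption, see accelerator.py
)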
@@ -221,16 +229,16 @@ def __init__(
 
             PartialState._shared_state["distributed_type"] = self.distributed_type
 
-    def _check_initialized(self, mixed_precision=None, cpu=None, autocast_backend=None):
+    def _check_initialized(self, mixed_precision=None, cpu=None): # autocast_backend=None):
         "Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized"
         super()._check_initialized(mixed_precision=mixed_precision, cpu=cpu)
         err = (
             "AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and "
             "pass `{flag}` to `Accelerator()`."
         )
-        if self.initialized:
-            if autocast_backend is not None and autocast_backend != self.autocast_backend:
-                raise ValueError(err.format(flag=f"autocast_backend='{autocast_backend}'"))
+        # if self.initialized:
+        #     if autocast_backend is not None and autocast_backend != self.autocast_backend:
+        #         raise ValueError(err.format(flag=f"autocast_backend='{autocast_backend}'"))
 
     @property
     def autocast_backend(self):