Enable fine tuning on HPU #660
base: main
Changes from all commits: a4370d7, 11a0a9a, 5740d80, 4b03a33, 9fccd07
@@ -46,7 +46,7 @@ class BatchLossManager:
     - Computing average losses for logging
     """

-    def __init__(self, model, accelerator, world_size: int, local_rank: int):
+    def __init__(self, model, accelerator, world_size: int, local_rank: int, device: str, save_grads: bool):
         """
         Initialize the BatchLossManager.

@@ -60,7 +60,12 @@ def __init__(self, model, accelerator, world_size: int, local_rank: int):
         self.accelerator: Accelerator = accelerator
         self.world_size: int = world_size
         self.local_rank: int = local_rank
-        self.torch_device = torch.device("cuda", local_rank)
+        if device == "hpu":
+            self.torch_device = torch.device("hpu")
+        else:
+            self.torch_device = torch.device("cuda", local_rank)
+        self.save_grads = save_grads
+        self.grad_buffer = {}

     def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]:
         """

@@ -84,6 +89,7 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
         grad_accum_steps = 0

         # process each minibatch
+        self.grad_buffer = {}
         for mb in batch:
             # extract minibatch-specific info
             micro_batch_size = mb["num_samples"]

@@ -102,12 +108,21 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
             )
             self.accelerator.backward(scaled_loss)

+            # save gradients
+            if self.save_grads and len(batch) > 1:
+                self._copy_grads_to_buffer()
+                self._zero_model_grads()

Member: @miseshkiwrk How come we need to move the gradients back and forth between the buffer? Does HPU not allow accumulating grads directly in the current objects? If this is the case, could you please leave a block comment somewhere describing the issue and why this is needed? Thanks in advance!

             # accumulate losses
             grad_accum_steps += 1
             accumulated_loss += raw_losses.main_loss
             if raw_losses.aux_loss is not None:
                 accumulated_aux_loss += raw_losses.aux_loss

+        # restore gradients
+        if self.grad_buffer:
+            self._restore_grads_from_buffer()
+
         # reduce metrics across ranks
         batch_total_samples, batch_total_length = self._reduce_metrics(
             batch_total_samples, batch_total_length

@@ -186,3 +201,22 @@ def _compute_average_loss(
         ).item()

         return avg_loss_across_ranks
+
+    def _copy_grads_to_buffer(self):
+        for p in self.model.parameters():
+            if p.grad is None:
+                continue
+            if p not in self.grad_buffer:
+                self.grad_buffer[p] = p.grad.detach().clone()
+            else:
+                self.grad_buffer[p] += p.grad.detach()
+
+    def _restore_grads_from_buffer(self):
+        for p in self.model.parameters():
+            if p in self.grad_buffer:
+                p.grad = self.grad_buffer[p]
+
+    def _zero_model_grads(self):
+        for p in self.model.parameters():
+            if p.grad is not None:
+                p.grad = None
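For context on the buffering question above, here is a minimal, self-contained sketch of the save/zero/restore flow in isolation. The toy `TinyNet` model, data, and optimizer are placeholders and not part of this PR; only the copy/zero/restore logic mirrors what `_copy_grads_to_buffer`, `_zero_model_grads`, and `_restore_grads_from_buffer` do.

```python
# Minimal sketch of the buffer-and-restore gradient accumulation pattern
# introduced in this diff. TinyNet and the random data are placeholders.
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x):
        return self.linear(x)

model = TinyNet()
grad_buffer = {}

for step in range(3):  # three "minibatches" in one accumulation window
    loss = model(torch.randn(2, 4)).pow(2).mean()
    loss.backward()

    # save gradients: accumulate into an external buffer instead of p.grad
    for p in model.parameters():
        if p.grad is None:
            continue
        if p not in grad_buffer:
            grad_buffer[p] = p.grad.detach().clone()
        else:
            grad_buffer[p] += p.grad.detach()

    # zero model grads so the next backward starts from a clean slate
    for p in model.parameters():
        p.grad = None

# restore gradients before the optimizer step
for p in model.parameters():
    if p in grad_buffer:
        p.grad = grad_buffer[p]

torch.optim.SGD(model.parameters(), lr=0.1).step()
```

On the surface this is equivalent to letting autograd accumulate into `p.grad` directly; the buffer only pays off if something on the HPU / torch.compile path resets or reuses `p.grad` between minibatches, which is presumably what the reviewer is asking the author to document.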
@@ -254,3 +254,8 @@ class TrainingArgs(BaseModel):
     log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(
         default="INFO"
     )
+
+    device: Optional[str] = None
+    torch_compile: bool = False
+    num_chunks: int = 1

Comment on lines +258 to +261
Member: I think these are good to expose, but what are your thoughts if we adjusted the field names like this? Would it also make sense to have it be …
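The reviewer's suggested field names are not captured in this view. As a rough sketch of the kind of constraint being discussed (the field names below are hypothetical, not the reviewer's actual proposal), a `Literal`-typed Pydantic field would reject anything other than the supported devices at validation time:

```python
# Illustrative sketch only: hypothetical field names, not the reviewer's proposal.
from typing import Literal

from pydantic import BaseModel, Field

class TrainingArgsSketch(BaseModel):
    # constrain to the two accelerators this PR targets, defaulting to CUDA
    device: Literal["cuda", "hpu"] = Field(default="cuda")
    use_torch_compile: bool = Field(default=False)
    num_chunks: int = Field(default=1, ge=1)

TrainingArgsSketch(device="hpu")    # ok
# TrainingArgsSketch(device="xpu") # raises a ValidationError
```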
@@ -111,7 +111,7 @@ def train(
     global_grad_norm = None

     # Initialize the batch loss manager
-    batch_loss_manager = BatchLossManager(model, accelerator, world_size, local_rank)
+    batch_loss_manager = BatchLossManager(model, accelerator, world_size, local_rank, args.device, args.device=="hpu" and args.torch_compile)

Member: @splotnikv Since you're setting the …

     # Blast through batches
     for epoch in range(args.current_epoch, args.num_epochs):
@@ -150,8 +150,12 @@ def train(
             elapsed_time = time.time() - start
             overall_throughput = batch_metrics.total_samples / elapsed_time
             current_lr = accelerator.lr_scheduler.get_last_lr()[0]
-            cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
-            cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+            if args.device == "hpu":
+                mem_allocated = torch.hpu.memory_allocated() / (1024**3)
+                malloc_retries = 0
+            else:
+                mem_allocated = torch.cuda.memory_allocated() / (1024**3)
+                malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
             global_grad_norm = (
                 model.get_global_grad_norm()
                 if hasattr(model, "get_global_grad_norm")

@@ -173,8 +177,8 @@ def train(
                     "rank": dist.get_rank(),
                     "overall_throughput": overall_throughput,
                     "lr": current_lr,
-                    "cuda_mem_allocated": cuda_mem_allocated,
-                    "cuda_malloc_retries": cuda_malloc_retries,
+                    ("hpu" if args.device == "hpu" else "cuda") + "_mem_allocated": mem_allocated,
+                    ("hpu" if args.device == "hpu" else "cuda") + "_malloc_retries": malloc_retries,
                     "num_loss_counted_tokens": batch_metrics.num_loss_counted_tokens,
                     "num_tokens_rank0": batch_metrics.total_length,
                     "batch_size": batch_metrics.total_samples,

@@ -206,7 +210,8 @@ def train(
             global_step += 1
             if local_rank == 0:
                 inner_pb.update(1)
-            torch.cuda.empty_cache()
+            if args.device != "hpu":
+                torch.cuda.empty_cache()
         if args.checkpoint_at_epoch:
             base_logger.debug(f"Saving checkpoint at epoch {epoch}")
             save_checkpoint(
@@ -284,17 +289,22 @@ def main(args):
     args.model_type = model_conf.model_type

     #### distributed init #####
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    if args.device == "hpu":
+        torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
+    else:
+        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
     args.local_rank = int(os.environ["LOCAL_RANK"])

     timeout = _get_collective_timeout()
-    if timeout is not None:
-        dist.init_process_group(timeout=timeout)
-    else:
-        dist.init_process_group()
+    backend = "hccl" if args.device == "hpu" else None
+    torch.distributed.init_process_group(backend=backend, timeout=timeout)

     args.global_rank = dist.get_rank()
-    tensor = torch.ByteTensor([False]).cuda()
+    if args.device == "hpu":
+        tensor = torch.ByteTensor([False]).to('hpu')
+    else:
+        tensor = torch.ByteTensor([False]).cuda()
     dist.all_reduce(tensor)
     dist.barrier()
@@ -335,6 +345,8 @@ def main(args):
         flash_enabled=flash_enabled,
         noise_alpha=args.NEFTune_alpha,
         lora_quant_bits=args.lora_quant_bits,
+        device=args.device,
+        torch_compile=args.torch_compile,
     )

     args.base_model_args = m.base_model_args

@@ -372,6 +384,8 @@ def main(args):
         num_workers=8,  # I don't like this but am setting it for consistency
         flash_enabled=flash_enabled,
         pad_token_id=pad_token_id,
+        num_chunks=args.num_chunks,
+        use_hpu_packer=(args.device=="hpu"),
     )

     if args.local_rank == 0:

@@ -410,6 +424,7 @@ def main(args):
         fsdp_cpu_offload_params=args.cpu_offload_params_fsdp,
         save_samples=args.save_samples,
         fsdp_use_orig_params=fsdp_should_use_orig_params,
+        device=args.device,
     )
     # optimizer needs model that has been prepared by accelerator
     # and then accelerator needs to be prepared AGAIN once optimizer is initialized
@@ -588,6 +603,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     if train_args.keep_last_checkpoint_only:
         command.append("--keep_last_checkpoint_only")

+    command.append(f"--device={train_args.device}")
+    if train_args.torch_compile:
+        command.append("--torch-compile")
+    command.append(f"--num-chunks={train_args.num_chunks}")
+
     logger.info("Running training command as subprocess: %s", " ".join(command))
     process = None
     interrupt: KeyboardInterrupt | Exception | None = None
@@ -789,8 +810,33 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
         action="store_true",
         help="Use Liger kernels for training.",
     )
+    parser.add_argument(
+        "--device",

Member: In the context of PyTorch, we typically assume … Since you're looking to allow for switching between HPU and CUDA style training, maybe we could change this to be …

+        type=str,
+        default=None,
+        help="PyTorch device to use.",
+    )
+    parser.add_argument(
+        '--torch-compile',
+        action='store_true', default=False,
+        help='Enable torch.compile, hpu only.'
+    )
+    parser.add_argument(
+        '--num-chunks',
+        type=int,
+        default=1,
+        help='Number of chunks to split dataset into for sequential training.'
+    )

     args = parser.parse_args()
-    set_random_seed(args.seed)

+    if args.device == "hpu":
+        import habana_frameworks.torch.core as htcore
+        import habana_frameworks.torch.distributed.hccl
+        from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+        adapt_transformers_to_gaudi()
+
+    set_random_seed(args.seed, args.device)
     main(args)

     """
@@ -30,7 +30,6 @@
 # Third Party
 from peft import LoraConfig
 from torch.optim import AdamW
-from transformers import Mxfp4Config  # pylint: disable=no-name-in-module
 from transformers import (
     AutoModelForCausalLM,
     BitsAndBytesConfig,
@@ -57,17 +56,22 @@ def __init__(
         flash_enabled: bool = False,
         lora_config: Optional[LoraConfig] = None,
         lora_quant_bits: int = 0,
+        device: Optional[str] = None,

Member: @splotnikv Is …

+        torch_compile: bool = False,
     ):
         self.lora_config = lora_config
         self.noise_alpha = noise_alpha
         self.tokenizer = tokenizer
         self.distributed_framework = distributed_framework
+        self.device = device
+        self.torch_compile = torch_compile
         quant_config = None

         # check model type & set on the mclasss
         self.is_gpt_oss = is_gpt_oss(model_path)
         if self.is_gpt_oss:
+            # Third Party
+            from transformers import Mxfp4Config  # pylint: disable=no-name-in-module

Member: @splotnikv How come this import is being moved from module-level to method-level?

             quant_config = Mxfp4Config(dequantize=True)

         # TODO: Add support for 8bit quantization
@@ -102,6 +106,19 @@ def __init__(
     def _post_model_init(self):
         """Common initialization steps that should happen after model initialization."""

+        if self.device == "hpu" and self.torch_compile:
+            cache_size_limit = 10*1000
+            torch._dynamo.config.cache_size_limit = cache_size_limit
+            torch._dynamo.config.accumulated_cache_size_limit = 2*cache_size_limit
+            self.model = torch.compile(self.model, backend="hpu_backend", dynamic=False)
+            for layer in self.model.model.layers:
+                layer.compile(backend="hpu_backend", dynamic=False)
+            if os.environ.get("RANK", '0') == '0':

Member: @splotnikv Are you trying to log from the main process on a given node here? The …

+                logger.info(
+                    f"torch.compile has been enabled"
+                )
+
         self.reconcile_tokenizer()
         if self.lora_config:
             self.model = self.prepare_peft_model()
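The reviewer's suggested alternative to the `RANK` check above is cut off in this view. If the intent is simply to log once per run, an equivalent check that does not depend on the environment variable is sketched below; this is an assumption about intent, not the author's or reviewer's final wording.

```python
# Sketch: gate logging on the global rank reported by torch.distributed rather
# than the RANK environment variable. Handles the case where the process group
# is not (yet) initialized at the call site.
import torch.distributed as dist

def is_global_rank_zero() -> bool:
    return not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0

if is_global_rank_zero():
    print("torch.compile has been enabled")
```

Note that under torchrun the `RANK` variable already holds the global rank, so the existing check should log once per run; the sketch just expresses that intent without reading the environment.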
@@ -270,7 +287,11 @@ def _is_causal_lm_model(self) -> bool:
             bool: True if the model is a causal language model, False otherwise.
         """
         # Third Party
-        return "ForCausalLM" in self.model.__class__.__name__
+        if self.device != "hpu":
+            class_name = self.model.__class__.__name__
+        else:
+            class_name = self.model._orig_mod.__class__.__name__ if self.model.__class__.__name__ == 'OptimizedModule' else self.model.__class__.__name__
+        return "ForCausalLM" in class_name

     def reconcile_tokenizer(self):
         if len(self.tokenizer) > self.model.config.vocab_size:
@@ -326,6 +347,17 @@ def reconcile_tokenizer(self):
         ):
             self.model.config.eos_token_id = self.tokenizer.eos_token_id

+        if self.device == "hpu":
+            model = self.model._orig_mod if self.model.__class__.__name__ == 'OptimizedModule' else self.model
+            class_name = model.__class__.__name__
+
+            replace_no_split_modules = {
+                'GaudiLlamaForCausalLM': ['GaudiLlamaDecoderLayer',]
+            }
+
+            if class_name in replace_no_split_modules:
+                model._no_split_modules = replace_no_split_modules[class_name]
+
         if not self._is_causal_lm_model():
             raise ValueError(
                 f"Model must be a causal language model, got {type(self.model)}"
@@ -386,9 +418,17 @@ def compute_loss(
             - Dataclass containing the raw pre-scaled losses
         """
         # Forward pass to get logits
+        hpu_args = {}
+        if self.device == "hpu":
+            hpu_args = {
+                "use_flash_attention":True,
+                "lazy_mode":False,
+            }
+
         output = self(
             **inputs,
             use_cache=False,
+            **hpu_args,
         )

         # Manual loss computation with reduction="none" following mini_trainer's exact approach
@@ -490,6 +530,8 @@ def __init__(
         flash_enabled: bool = False,
         lora_config: Optional[LoraConfig] = None,
         lora_quant_bits: int = 0,
+        device: Optional[str] = None,
+        torch_compile: bool = False,
     ):
         super().__init__(
             model_path=model_path,

@@ -499,6 +541,8 @@ def __init__(
             flash_enabled=flash_enabled,
             lora_config=lora_config,
             lora_quant_bits=lora_quant_bits,
+            device=device,
+            torch_compile=torch_compile,
         )
         self.model = AutoModelForCausalLM.from_pretrained(**self.base_model_args)
         self._post_model_init()
Member: Instead of just `str`, you might want to create a new type that constrains the choices between cuda and hpu:
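The snippet the reviewer refers to is not captured here. As a hypothetical reconstruction of that kind of constraint, a `Literal` alias could replace the bare `str` annotation on the constructor:

```python
# Hypothetical reconstruction; the reviewer's actual snippet is not shown above.
from typing import Literal, Optional

SupportedDevice = Literal["cuda", "hpu"]

class ModelSketch:
    def __init__(self, device: Optional[SupportedDevice] = None, torch_compile: bool = False):
        # type checkers now flag any value other than "cuda", "hpu", or None
        self.device = device
        self.torch_compile = torch_compile
```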