From a4370d7ea07b07526775b6d67af857615f7662f9 Mon Sep 17 00:00:00 2001 From: Sergey Plotnikov Date: Tue, 23 Sep 2025 08:31:08 -0700 Subject: [PATCH 1/5] Enable fine tuning on HPU Signed-off-by: Sergey Plotnikov --- src/instructlab/training/accelerator.py | 10 +++- .../training/batch_loss_manager.py | 7 ++- src/instructlab/training/config.py | 2 + src/instructlab/training/main_ds.py | 56 ++++++++++++++----- src/instructlab/training/model.py | 44 ++++++++++++++- src/instructlab/training/utils.py | 10 +++- 6 files changed, 108 insertions(+), 21 deletions(-) diff --git a/src/instructlab/training/accelerator.py b/src/instructlab/training/accelerator.py index 4baa7c0e..afc6fc88 100644 --- a/src/instructlab/training/accelerator.py +++ b/src/instructlab/training/accelerator.py @@ -43,6 +43,7 @@ def __init__( deepspeed_cpu_offload_optimizer_ratio: Optional[float] = None, fsdp_cpu_offload_params: Optional[bool] = False, fsdp_use_orig_params: Optional[bool] = False, + device: Optional[str] = None, ): self.samples_per_gpu = samples_per_gpu self.save_samples = save_samples @@ -61,6 +62,7 @@ def __init__( self.fsdp_cpu_offload_params = fsdp_cpu_offload_params self.fsdp_use_orig_params = fsdp_use_orig_params self.lr_scheduler = None + self.device_str = device #should be before first use, that happens in self.get_fsdp_config() if self.distributed_framework == DistributedBackend.DEEPSPEED: # Standard accel_args = { @@ -81,6 +83,10 @@ def __init__( "fsdp_plugin": self.get_fsdp_config(), "mixed_precision": "bf16", } + if device == "hpu": + from optimum.habana.accelerate import GaudiAccelerator as TransformersAccel + else: + from accelerate import Accelerator as TransformersAccel self.accelerator = TransformersAccel( **accel_args, ) @@ -159,7 +165,9 @@ def get_fsdp_config(self): use_orig_params=self.fsdp_use_orig_params, # TODO(osilkin): expose switch for fp32 reduction ) - + if self.device_str == "hpu": + fsdp_plugin.use_orig_params=True + fsdp_plugin.sync_module_states=True return fsdp_plugin def get_ds_plugin( diff --git a/src/instructlab/training/batch_loss_manager.py b/src/instructlab/training/batch_loss_manager.py index f0e10a89..25996e8e 100644 --- a/src/instructlab/training/batch_loss_manager.py +++ b/src/instructlab/training/batch_loss_manager.py @@ -46,7 +46,7 @@ class BatchLossManager: - Computing average losses for logging """ - def __init__(self, model, accelerator, world_size: int, local_rank: int): + def __init__(self, model, accelerator, world_size: int, local_rank: int, device: str): """ Initialize the BatchLossManager. 
@@ -60,7 +60,10 @@ def __init__(self, model, accelerator, world_size: int, local_rank: int): self.accelerator: Accelerator = accelerator self.world_size: int = world_size self.local_rank: int = local_rank - self.torch_device = torch.device("cuda", local_rank) + if device == "hpu": + self.torch_device = torch.device("hpu", local_rank) + else: + self.torch_device = torch.device("cuda", local_rank) def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]: """ diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index b561a8d6..795144c7 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -254,3 +254,5 @@ class TrainingArgs(BaseModel): log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field( default="INFO" ) + + device: Optional[str] = None diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index dc8a171e..41d89ec2 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -111,7 +111,7 @@ def train( global_grad_norm = None # Initialize the batch loss manager - batch_loss_manager = BatchLossManager(model, accelerator, world_size, local_rank) + batch_loss_manager = BatchLossManager(model, accelerator, world_size, local_rank, args.device) # Blast through batches for epoch in range(args.current_epoch, args.num_epochs): @@ -150,8 +150,12 @@ def train( elapsed_time = time.time() - start overall_throughput = batch_metrics.total_samples / elapsed_time current_lr = accelerator.lr_scheduler.get_last_lr()[0] - cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3) - cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] + if args.device == "hpu": + mem_allocated = torch.hpu.memory_allocated() / (1024**3) + malloc_retries = 0 + else: + mem_allocated = torch.cuda.memory_allocated() / (1024**3) + malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] global_grad_norm = ( model.get_global_grad_norm() if hasattr(model, "get_global_grad_norm") @@ -173,8 +177,8 @@ def train( "rank": dist.get_rank(), "overall_throughput": overall_throughput, "lr": current_lr, - "cuda_mem_allocated": cuda_mem_allocated, - "cuda_malloc_retries": cuda_malloc_retries, + ("hpu" if args.device == "hpu" else "cuda") + "_mem_allocated": mem_allocated, + ("hpu" if args.device == "hpu" else "cuda") + "_malloc_retries": malloc_retries, "num_loss_counted_tokens": batch_metrics.num_loss_counted_tokens, "num_tokens_rank0": batch_metrics.total_length, "batch_size": batch_metrics.total_samples, @@ -206,7 +210,8 @@ def train( global_step += 1 if local_rank == 0: inner_pb.update(1) - torch.cuda.empty_cache() + if args.device != "hpu": + torch.cuda.empty_cache() if args.checkpoint_at_epoch: base_logger.debug(f"Saving checkpoint at epoch {epoch}") save_checkpoint( @@ -284,17 +289,22 @@ def main(args): args.model_type = model_conf.model_type #### distributed init ##### - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + if args.device == "hpu": + torch.hpu.set_device(int(os.environ["LOCAL_RANK"])) + else: + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) args.local_rank = int(os.environ["LOCAL_RANK"]) timeout = _get_collective_timeout() - if timeout is not None: - dist.init_process_group(timeout=timeout) - else: - dist.init_process_group() + backend = "hccl" if args.device == "hpu" else None + torch.distributed.init_process_group(backend=backend, timeout=timeout) + args.global_rank = dist.get_rank() - tensor = 
torch.ByteTensor([False]).cuda() + if args.device == "hpu": + tensor = torch.ByteTensor([False]).to('hpu') + else: + tensor = torch.ByteTensor([False]).cuda() dist.all_reduce(tensor) dist.barrier() @@ -335,6 +345,7 @@ def main(args): flash_enabled=flash_enabled, noise_alpha=args.NEFTune_alpha, lora_quant_bits=args.lora_quant_bits, + device=args.device, ) args.base_model_args = m.base_model_args @@ -410,6 +421,7 @@ def main(args): fsdp_cpu_offload_params=args.cpu_offload_params_fsdp, save_samples=args.save_samples, fsdp_use_orig_params=fsdp_should_use_orig_params, + device=args.device, ) # optimizer needs model that has been prepared by accelerator # and then accelerator needs to be prepared AGAIN once optimizer is initialized @@ -588,6 +600,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if train_args.keep_last_checkpoint_only: command.append("--keep_last_checkpoint_only") + command.append( + f"--device={train_args.device}" + ) + logger.info("Running training command as subprocess: %s", " ".join(command)) process = None interrupt: KeyboardInterrupt | Exception | None = None @@ -789,8 +805,22 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: action="store_true", help="Use Liger kernels for training.", ) + parser.add_argument( + "--device", + type=str, + default=None, + help="PyTorch device to use.", + ) + args = parser.parse_args() - set_random_seed(args.seed) + + if args.device == "hpu": + import habana_frameworks.torch.core as htcore + import habana_frameworks.torch.distributed.hccl + from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi + adapt_transformers_to_gaudi() + + set_random_seed(args.seed, args.device) main(args) """ diff --git a/src/instructlab/training/model.py b/src/instructlab/training/model.py index de863e1d..14174161 100644 --- a/src/instructlab/training/model.py +++ b/src/instructlab/training/model.py @@ -30,7 +30,6 @@ # Third Party from peft import LoraConfig from torch.optim import AdamW -from transformers import Mxfp4Config # pylint: disable=no-name-in-module from transformers import ( AutoModelForCausalLM, BitsAndBytesConfig, @@ -57,17 +56,20 @@ def __init__( flash_enabled: bool = False, lora_config: Optional[LoraConfig] = None, lora_quant_bits: int = 0, + device: Optional[str] = None, ): self.lora_config = lora_config self.noise_alpha = noise_alpha self.tokenizer = tokenizer self.distributed_framework = distributed_framework + self.device = device quant_config = None # check model type & set on the mclasss self.is_gpt_oss = is_gpt_oss(model_path) if self.is_gpt_oss: # Third Party + from transformers import Mxfp4Config # pylint: disable=no-name-in-module quant_config = Mxfp4Config(dequantize=True) # TODO: Add support for 8bit quantization @@ -102,6 +104,19 @@ def __init__( def _post_model_init(self): """Common initialization steps that should happen after model initialization.""" + + if self.device == "hpu" and os.getenv("HPU_ENABLE_TORCH_COMPILE", False): + cache_size_limit = 10*1000 + torch._dynamo.config.cache_size_limit = cache_size_limit + torch._dynamo.config.accumulated_cache_size_limit = 2*cache_size_limit + self.model = torch.compile(self.model, backend="hpu_backend", dynamic=False) + for layer in self.model.model.layers: + layer.compile(backend="hpu_backend", dynamic=False) + if os.environ.get("RANK", '0') == '0': + logger.info( + f"torch.compile has been enabled" + ) + self.reconcile_tokenizer() if self.lora_config: self.model = self.prepare_peft_model() @@ 
-270,7 +285,11 @@ def _is_causal_lm_model(self) -> bool: bool: True if the model is a causal language model, False otherwise. """ # Third Party - return "ForCausalLM" in self.model.__class__.__name__ + if self.device != "hpu": + class_name = self.model.__class__.__name__ + else: + class_name = self.model._orig_mod.__class__.__name__ if self.model.__class__.__name__ == 'OptimizedModule' else self.model.__class__.__name__ + return "ForCausalLM" in class_name def reconcile_tokenizer(self): if len(self.tokenizer) > self.model.config.vocab_size: @@ -326,6 +345,17 @@ def reconcile_tokenizer(self): ): self.model.config.eos_token_id = self.tokenizer.eos_token_id + if self.device == "hpu": + model = self.model._orig_mod if self.model.__class__.__name__ == 'OptimizedModule' else self.model + class_name = model.__class__.__name__ + + replace_no_split_modules = { + 'GaudiLlamaForCausalLM': ['GaudiLlamaDecoderLayer',] + } + + if class_name in replace_no_split_modules: + model._no_split_modules = replace_no_split_modules[class_name] + if not self._is_causal_lm_model(): raise ValueError( f"Model must be a causal language model, got {type(self.model)}" @@ -386,9 +416,17 @@ def compute_loss( - Dataclass containing the raw pre-scaled losses """ # Forward pass to get logits + hpu_args = {} + if self.device == "hpu": + hpu_args = { + "use_flash_attention":True, + "lazy_mode":False, + } + output = self( **inputs, use_cache=False, + **hpu_args, ) # Manual loss computation with reduction="none" following mini_trainer's exact approach @@ -490,6 +528,7 @@ def __init__( flash_enabled: bool = False, lora_config: Optional[LoraConfig] = None, lora_quant_bits: int = 0, + device: Optional[str] = None, ): super().__init__( model_path=model_path, @@ -499,6 +538,7 @@ def __init__( flash_enabled=flash_enabled, lora_config=lora_config, lora_quant_bits=lora_quant_bits, + device=device, ) self.model = AutoModelForCausalLM.from_pretrained(**self.base_model_args) self._post_model_init() diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 275a4b7e..6ffad25a 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -318,6 +318,7 @@ def reduce_sum_forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + **(_deprecated_arguments if model.device=="hpu" else {}), ) return_dict = isinstance(output, dict) @@ -775,13 +776,16 @@ def _get_state_dict_patched(model, unwrap=False): accelerator.get_state_dict = get_state_dict_unpatched -def set_random_seed(seed): +def set_random_seed(seed, device: str): if seed is not None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - + if device == "hpu": + torch.hpu.manual_seed_all(seed) + else: + torch.cuda.manual_seed_all(seed) + # TODO: move this to also live in the `Model` object def save_checkpoint( From 11a0a9a270c36e2ffe5416b1b60cb0f0bfa08bf8 Mon Sep 17 00:00:00 2001 From: Sergey Plotnikov Date: Thu, 9 Oct 2025 10:35:54 -0700 Subject: [PATCH 2/5] Add HPU specific sampler Signed-off-by: Sergey Plotnikov --- .../training/padded_batch_packer.py | 206 +++++++++++++++++- src/instructlab/training/sampler.py | 41 +++- 2 files changed, 240 insertions(+), 7 deletions(-) diff --git a/src/instructlab/training/padded_batch_packer.py b/src/instructlab/training/padded_batch_packer.py index 3cad913b..194f567c 100644 --- a/src/instructlab/training/padded_batch_packer.py +++ b/src/instructlab/training/padded_batch_packer.py @@ -57,7 +57,7 
@@ def _find_optimal_batch_size( for size in range(2, max_sequences + 1): # Check if this batch size fits within token limit padded_tokens = _compute_padded_tokens(lengths, start_idx, start_idx + size) - if padded_tokens > max_tokens: + if padded_tokens > max_tokens: break # Compute padding ratio @@ -286,3 +286,207 @@ def compute_padding_stats(batch_lengths: list[int], batches: list[list[int]]) -> "padding_ratio": padding_ratio, "num_batches": len([b for b in batches if b and b[0] != -1]), } + + +import math +def compute_bucket_size(length: int) -> int: + """ + Bucketing algorithm based on the most significant bit of the sample length. + Finds the most significant bit position 'S', then divides the range + [2^S, 2^(S+1)] into 16 equal buckets of size 2^(S-4). + Limits overhead to at most 1/16 while reducing number of graph recompilations. + """ + msb_pos = length.bit_length() + alignment = (1 << (msb_pos - 4)) if msb_pos >= 4 else 1 + return math.ceil(length / alignment) * alignment + + +def _batch_cost(lengths: list[int]) -> float: + if not (isinstance(lengths, list) and len(lengths) > 0): + raise TypeError(f"wrong input type") + return lengths[0] * len(lengths) + + +def _batch_size_in_tokens(lengths: list[int]) -> float: + if not (isinstance(lengths, list) and len(lengths) > 0): + raise TypeError(f"wrong input type") + return lengths[0] * len(lengths) + + +def _check_batch_cost( + sorted_lengths:list[int], + num_ranks:int64, + max_cost:float, + ) -> list[int]: + if not (isinstance(sorted_lengths, list) and len(sorted_lengths) > 0): + raise TypeError(f"wrong input type") + + bins = [[] for _ in range(num_ranks)] + current_bin = 0 + + for sample_length in sorted_lengths: + while True: + # try to add to current bin + bins[current_bin].append(sample_length) + cost = _batch_cost(bins[current_bin]) + + # go to next sample if current fits + if cost < max_cost: + break + + # bin overflow, remove sample from current bin and try next bin + if len(bins[current_bin]) == 1: + return None + + if current_bin >= num_ranks - 1: + return None + + bins[current_bin].pop() + current_bin += 1 + + bin_sizes = [len(bin) for bin in bins] + return bin_sizes + + +def _distribute_samples_across_ranks( + sorted_lengths: list[int], + max_tokens: int64, + num_ranks: int64, + rank: int64, + ) -> list[int]: + + # compute cost range, from 0 to max possible cost when we put all samples in one bin + lower_bound = 0 + upper_bound = _batch_cost(sorted_lengths) + + # find optimal distribution based on batch cost + prev_bin_sizes = None + epsilon = 1e-6 + while upper_bound - lower_bound > epsilon: + mid = (lower_bound + upper_bound) / 2 + cur_bin_sizes = _check_batch_cost(sorted_lengths, num_ranks, mid) + + if cur_bin_sizes is None: + lower_bound = mid + epsilon + continue + else: + upper_bound = mid + if prev_bin_sizes == cur_bin_sizes: + break + prev_bin_sizes = cur_bin_sizes + + # sanity check + if prev_bin_sizes is not None: + if (any(size <= 0 for size in prev_bin_sizes) or + len(prev_bin_sizes) != num_ranks or + sum(prev_bin_sizes) != len(sorted_lengths)): + raise ValueError("Something went wrong") + + return prev_bin_sizes + + +def _check_batch_size_in_tokens( + sorted_lengths:list[int], + max_tokens: int64, + minibatch_sizes:list[int], + ) -> bool: + + first_sample_idx = 0 + for bs in minibatch_sizes: + minibatch = sorted_lengths[first_sample_idx:first_sample_idx + bs] + if _batch_size_in_tokens(minibatch) >= max_tokens: + return False + first_sample_idx += bs + return True + + +def 
_compute_sample_indeces_for_current_rank( + lengths: list[int], + minibatches: list[list[int]], + rank: int64, + ) -> list[list[int]]: + + sorted_indices = np.argsort(lengths)[::-1] + minibatch_indices = [] + for minibatch in minibatches: + minibatch_indices.append(sorted_indices[sum(minibatch[:rank]):sum(minibatch[:rank+1])].tolist()) + return minibatch_indices + + +def _compute_num_samples_in_grad_accum_step( + num_samples: int, + grad_accum : int + ) -> list[int]: + + if grad_accum <= 0: + return [] + if grad_accum == 1: + return [num_samples] + + step_size = num_samples // grad_accum + remainder = num_samples % grad_accum + result = [step_size] * grad_accum + result[-1] += remainder + return result + + +def _batch_packing_core_hpu( + lengths: list[int], + max_tokens: int64, + num_ranks: int64, + rank: int64, +) -> list[list[int]]: + + # try different gradient accumulation values + for grad_accum in [1, 2, 4]: + + # break input batch to several gradient accumulation steps + grad_accum_step_sizes = _compute_num_samples_in_grad_accum_step(len(lengths), grad_accum) + + first_sample_idx = 0 + minibatches = [] + for step_size in grad_accum_step_sizes: + step_lengths = lengths[first_sample_idx:first_sample_idx + step_size] + first_sample_idx += step_size + sorted_lengths = sorted(step_lengths, reverse=True) + + # find optimal sample distribution for single step based on computation cost + minibatch_sizes = _distribute_samples_across_ranks(sorted_lengths, max_tokens, num_ranks, rank) + if minibatch_sizes is None: + raise ValueError("Something went wrong") + + # check if found distribution fits in token limit + if not _check_batch_size_in_tokens(sorted_lengths, max_tokens, minibatch_sizes): + # does not fit, increase number of gradient accumulation steps + break + minibatches.append(minibatch_sizes) + + #check if we found suitable sample distribution + if len(minibatches) == grad_accum: + break + + # sanity check + if not ( + len(minibatches) == grad_accum and + all(len(minibatch) == num_ranks for minibatch in minibatches) and + sum(sum(minibatch) for minibatch in minibatches) == len(lengths) + ): + raise ValueError("Could not distribute samples across ranks") + + + # compute indices for current rank + minibatch_indices = _compute_sample_indeces_for_current_rank(lengths, minibatches, rank) + return minibatch_indices + + +def batch_lengths_to_minibatches_hpu( + batch_lengths: list[int], + max_tokens_per_rank: int, + num_ranks: int, + rank: int, +) -> list[list[int]]: + + if not batch_lengths: + return [] + result = _batch_packing_core_hpu( batch_lengths, max_tokens_per_rank, num_ranks, rank) + return result diff --git a/src/instructlab/training/sampler.py b/src/instructlab/training/sampler.py index be2b3f4b..70306077 100644 --- a/src/instructlab/training/sampler.py +++ b/src/instructlab/training/sampler.py @@ -14,6 +14,8 @@ from instructlab.training.batch_packer import batch_lengths_to_minibatches_lpt from instructlab.training.padded_batch_packer import ( batch_lengths_to_minibatches_padded, + batch_lengths_to_minibatches_hpu, + compute_bucket_size, ) from instructlab.training.type_definitions import CollatedItem @@ -24,10 +26,11 @@ class EpochSampler(Sampler): Replaces the naive distributed sampler with reproducible epoch-based shuffling. 
""" - def __init__(self, len_data: int, seed: int = 67, epoch: int = 0): + def __init__(self, len_data: int, seed: int = 67, epoch: int = 0, lengths : list[int] = None): self.len_data = len_data self.seed = seed self._epoch = epoch + self.lengths = lengths @property def epoch(self) -> int: @@ -37,9 +40,33 @@ def set_epoch(self, epoch: int): self._epoch = epoch def generate_samples(self): - g = torch.Generator() - g.manual_seed(self.seed + self._epoch) - samples = torch.randperm(self.len_data, generator=g).tolist() + if False : # TODO use device != 'hpu' + g = torch.Generator() + g.manual_seed(self.seed + self._epoch) + samples = torch.randperm(self.len_data, generator=g).tolist() + else: + sorted_indices = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i], reverse=True) + + # break list of indices into several chunks + num_chunks = 4 # TODO add command line parameter + chunk_size = len(sorted_indices) // num_chunks + chunks = [] + for i in range(num_chunks): + if i == num_chunks - 1: + chunks.append(sorted_indices[i * chunk_size:]) + else: + chunks.append(sorted_indices[i * chunk_size:(i + 1) * chunk_size]) + + import random + random.seed(self.seed + self._epoch) + random.shuffle(chunks) + for chunk in chunks: + random.shuffle(chunk) + + # flatten the list + from itertools import chain + samples = list(chain.from_iterable(chunks)) + return samples def __iter__(self): @@ -148,6 +175,7 @@ def padded_mb_collate_fn( # Find max length in this batch max_len = max(len(item["input_ids"]) for item in minibatch) + max_len = compute_bucket_size(max_len) #TODO add if device == 'hpu' # Prepare lists for batched tensors padded_input_ids = [] @@ -247,7 +275,8 @@ def __init__( self.batch_packer = batch_lengths_to_minibatches_lpt self.collate_fn = mb_collate_fn else: - self.batch_packer = batch_lengths_to_minibatches_padded + #self.batch_packer = batch_lengths_to_minibatches_padded + self.batch_packer = batch_lengths_to_minibatches_hpu #TODO add if device == 'hpu' # Create a wrapper for padded collate that includes pad_token_id self.collate_fn = ( lambda minibatch, batch_num_loss_counted_tokens: padded_mb_collate_fn( @@ -365,7 +394,7 @@ def get_data_loader( DataLoader configured with appropriate collator based on flash_enabled """ dataset = TokenDataset(data_path) - sampler = EpochSampler(len(dataset), seed=seed) + sampler = EpochSampler(len(dataset), seed=seed, lengths = dataset.get_lengths()) # Create unified collator with appropriate mode collate_fn = MaxTokensPerRankCollator( From 5740d80cd41643bd09ffd2fe60464ea119e3d80c Mon Sep 17 00:00:00 2001 From: Sergey Plotnikov Date: Thu, 16 Oct 2025 08:20:05 -0700 Subject: [PATCH 3/5] Add small batch support and fix odd last batch processing --- .../training/padded_batch_packer.py | 45 ++++++++++++------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/src/instructlab/training/padded_batch_packer.py b/src/instructlab/training/padded_batch_packer.py index 194f567c..8f5837e3 100644 --- a/src/instructlab/training/padded_batch_packer.py +++ b/src/instructlab/training/padded_batch_packer.py @@ -334,9 +334,9 @@ def _check_batch_cost( if cost < max_cost: break - # bin overflow, remove sample from current bin and try next bin + # bin overflow, move last sample to next bin if possible if len(bins[current_bin]) == 1: - return None + break if current_bin >= num_ranks - 1: return None @@ -361,26 +361,25 @@ def _distribute_samples_across_ranks( # find optimal distribution based on batch cost prev_bin_sizes = None - epsilon = 1e-6 + epsilon = 1. 
# cost has same OOM as batch token count while upper_bound - lower_bound > epsilon: mid = (lower_bound + upper_bound) / 2 cur_bin_sizes = _check_batch_cost(sorted_lengths, num_ranks, mid) if cur_bin_sizes is None: lower_bound = mid + epsilon - continue else: upper_bound = mid - if prev_bin_sizes == cur_bin_sizes: - break prev_bin_sizes = cur_bin_sizes # sanity check if prev_bin_sizes is not None: - if (any(size <= 0 for size in prev_bin_sizes) or - len(prev_bin_sizes) != num_ranks or - sum(prev_bin_sizes) != len(sorted_lengths)): - raise ValueError("Something went wrong") + if (len(prev_bin_sizes) != num_ranks or sum(prev_bin_sizes) != len(sorted_lengths)): + raise ValueError("Something went wrong, we lost samples during distribution across ranks") + + if any(size == 0 for size in prev_bin_sizes): + if any(size > 1 for size in prev_bin_sizes): + raise ValueError("Something went wrong, we put more than one sample per rank for small batch") return prev_bin_sizes @@ -393,10 +392,11 @@ def _check_batch_size_in_tokens( first_sample_idx = 0 for bs in minibatch_sizes: - minibatch = sorted_lengths[first_sample_idx:first_sample_idx + bs] - if _batch_size_in_tokens(minibatch) >= max_tokens: - return False - first_sample_idx += bs + if bs > 0: + minibatch = sorted_lengths[first_sample_idx:first_sample_idx + bs] + if _batch_size_in_tokens(minibatch) >= max_tokens: + return False + first_sample_idx += bs return True @@ -408,8 +408,11 @@ def _compute_sample_indeces_for_current_rank( sorted_indices = np.argsort(lengths)[::-1] minibatch_indices = [] + first_sample_idx = 0 for minibatch in minibatches: - minibatch_indices.append(sorted_indices[sum(minibatch[:rank]):sum(minibatch[:rank+1])].tolist()) + first_sample_idx += sum(minibatch[:rank]) + minibatch_indices.append(sorted_indices[first_sample_idx:first_sample_idx + minibatch[rank]].tolist()) + first_sample_idx += sum(minibatch[rank:]) return minibatch_indices @@ -476,6 +479,18 @@ def _batch_packing_core_hpu( # compute indices for current rank minibatch_indices = _compute_sample_indeces_for_current_rank(lengths, minibatches, rank) + + # sanity check + from itertools import chain + all_indices = list(chain.from_iterable(minibatch_indices)) + if len(all_indices) != len(set(all_indices)): + raise ValueError("Something went wrong, duplicated indices in the list") + + # add one dummy sample to each empty minibatch + for minibatch in minibatch_indices: + if len(minibatch) == 0: + minibatch.append(-1) + return minibatch_indices From 4b03a334d905348e57624bb3a1ca61944eda0e8b Mon Sep 17 00:00:00 2001 From: Sergey Plotnikov Date: Wed, 22 Oct 2025 07:31:07 -0700 Subject: [PATCH 4/5] Add command line parameters to control HPU specific sampler Signed-off-by: Sergey Plotnikov --- .../training/batch_loss_manager.py | 2 +- src/instructlab/training/config.py | 3 ++ src/instructlab/training/main_ds.py | 21 ++++++++++-- src/instructlab/training/model.py | 4 ++- src/instructlab/training/sampler.py | 33 ++++++++++++------- 5 files changed, 46 insertions(+), 17 deletions(-) diff --git a/src/instructlab/training/batch_loss_manager.py b/src/instructlab/training/batch_loss_manager.py index 25996e8e..e0aa14df 100644 --- a/src/instructlab/training/batch_loss_manager.py +++ b/src/instructlab/training/batch_loss_manager.py @@ -61,7 +61,7 @@ def __init__(self, model, accelerator, world_size: int, local_rank: int, device: self.world_size: int = world_size self.local_rank: int = local_rank if device == "hpu": - self.torch_device = torch.device("hpu", local_rank) + self.torch_device = 
torch.device("hpu") else: self.torch_device = torch.device("cuda", local_rank) diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 795144c7..ff03b49b 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -256,3 +256,6 @@ class TrainingArgs(BaseModel): ) device: Optional[str] = None + torch_compile: bool = False + num_chunks: int = 1 + \ No newline at end of file diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 41d89ec2..e47ed225 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -383,6 +383,8 @@ def main(args): num_workers=8, # I don't like this but am setting it for consistency flash_enabled=flash_enabled, pad_token_id=pad_token_id, + num_chunks=args.num_chunks, + use_hpu_packer=(args.device=="hpu"), ) if args.local_rank == 0: @@ -600,9 +602,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if train_args.keep_last_checkpoint_only: command.append("--keep_last_checkpoint_only") - command.append( - f"--device={train_args.device}" - ) + command.append(f"--device={train_args.device}") + if train_args.torch_compile: + command.append("--torch-compile") + command.append(f"--num-chunks={train_args.num_chunks}") + logger.info("Running training command as subprocess: %s", " ".join(command)) process = None @@ -811,6 +815,17 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: default=None, help="PyTorch device to use.", ) + parser.add_argument( + '--torch-compile', + action='store_true', default=False, + help='Enable torch.compile, hpu only.' + ) + parser.add_argument( + '--num-chunks', + type=int, + default=1, + help='Number of chunks to split dataset into for sequential training.' + ) args = parser.parse_args() diff --git a/src/instructlab/training/model.py b/src/instructlab/training/model.py index 14174161..97540596 100644 --- a/src/instructlab/training/model.py +++ b/src/instructlab/training/model.py @@ -57,12 +57,14 @@ def __init__( lora_config: Optional[LoraConfig] = None, lora_quant_bits: int = 0, device: Optional[str] = None, + torch_compile: bool = False, ): self.lora_config = lora_config self.noise_alpha = noise_alpha self.tokenizer = tokenizer self.distributed_framework = distributed_framework self.device = device + self.torch_compile = torch_compile quant_config = None # check model type & set on the mclasss @@ -105,7 +107,7 @@ def __init__( def _post_model_init(self): """Common initialization steps that should happen after model initialization.""" - if self.device == "hpu" and os.getenv("HPU_ENABLE_TORCH_COMPILE", False): + if self.device == "hpu" and self.torch_compile: cache_size_limit = 10*1000 torch._dynamo.config.cache_size_limit = cache_size_limit torch._dynamo.config.accumulated_cache_size_limit = 2*cache_size_limit diff --git a/src/instructlab/training/sampler.py b/src/instructlab/training/sampler.py index 70306077..766e9c27 100644 --- a/src/instructlab/training/sampler.py +++ b/src/instructlab/training/sampler.py @@ -26,11 +26,12 @@ class EpochSampler(Sampler): Replaces the naive distributed sampler with reproducible epoch-based shuffling. 
""" - def __init__(self, len_data: int, seed: int = 67, epoch: int = 0, lengths : list[int] = None): + def __init__(self, len_data: int, seed: int = 67, epoch: int = 0, lengths: list[int] = None, num_chunks: int = 1): self.len_data = len_data self.seed = seed self._epoch = epoch self.lengths = lengths + self.num_chunks = num_chunks @property def epoch(self) -> int: @@ -40,7 +41,7 @@ def set_epoch(self, epoch: int): self._epoch = epoch def generate_samples(self): - if False : # TODO use device != 'hpu' + if self.num_chunks == 1 : g = torch.Generator() g.manual_seed(self.seed + self._epoch) samples = torch.randperm(self.len_data, generator=g).tolist() @@ -48,11 +49,10 @@ def generate_samples(self): sorted_indices = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i], reverse=True) # break list of indices into several chunks - num_chunks = 4 # TODO add command line parameter - chunk_size = len(sorted_indices) // num_chunks + chunk_size = len(sorted_indices) // self.num_chunks chunks = [] - for i in range(num_chunks): - if i == num_chunks - 1: + for i in range(self.num_chunks): + if i == self.num_chunks - 1: chunks.append(sorted_indices[i * chunk_size:]) else: chunks.append(sorted_indices[i * chunk_size:(i + 1) * chunk_size]) @@ -140,7 +140,7 @@ def mb_collate_fn(minibatch, batch_num_loss_counted_tokens) -> CollatedItem: def padded_mb_collate_fn( - minibatch, batch_num_loss_counted_tokens, pad_token_id=0 + minibatch, batch_num_loss_counted_tokens, pad_token_id=0, use_hpu_packer: bool = False, ) -> CollatedItem: """Collates a list of samples into a padded batch for standard attention. @@ -175,7 +175,8 @@ def padded_mb_collate_fn( # Find max length in this batch max_len = max(len(item["input_ids"]) for item in minibatch) - max_len = compute_bucket_size(max_len) #TODO add if device == 'hpu' + if use_hpu_packer: + max_len = compute_bucket_size(max_len) # Prepare lists for batched tensors padded_input_ids = [] @@ -249,6 +250,7 @@ def __init__( dummy_sample=None, flash_enabled: bool = True, pad_token_id: int = 0, + use_hpu_packer: bool = False, ): self.max_tokens_per_rank = max_tokens_per_rank self.flash_enabled = flash_enabled @@ -275,12 +277,14 @@ def __init__( self.batch_packer = batch_lengths_to_minibatches_lpt self.collate_fn = mb_collate_fn else: - #self.batch_packer = batch_lengths_to_minibatches_padded - self.batch_packer = batch_lengths_to_minibatches_hpu #TODO add if device == 'hpu' + if not use_hpu_packer: + self.batch_packer = batch_lengths_to_minibatches_padded + else: + self.batch_packer = batch_lengths_to_minibatches_hpu # Create a wrapper for padded collate that includes pad_token_id self.collate_fn = ( lambda minibatch, batch_num_loss_counted_tokens: padded_mb_collate_fn( - minibatch, batch_num_loss_counted_tokens, pad_token_id + minibatch, batch_num_loss_counted_tokens, pad_token_id, use_hpu_packer ) ) @@ -375,6 +379,8 @@ def get_data_loader( num_workers: int = 0, flash_enabled: bool = True, pad_token_id: int = 0, + num_chunks: int = 1, + use_hpu_packer: bool = False, ): """Create a data loader with epoch-based sampling and batch packing. 
@@ -389,12 +395,14 @@ def get_data_loader( num_workers: Number of data loading workers flash_enabled: Whether flash attention is enabled (affects collation strategy) pad_token_id: Token ID to use for padding (only used when flash_enabled=False) + num_chunks: Number of chunks to split dataset into for sequential training + use_hpu_packer: Use HPU specific packer Returns: DataLoader configured with appropriate collator based on flash_enabled """ dataset = TokenDataset(data_path) - sampler = EpochSampler(len(dataset), seed=seed, lengths = dataset.get_lengths()) + sampler = EpochSampler(len(dataset), seed=seed, lengths = dataset.get_lengths(), num_chunks = num_chunks) # Create unified collator with appropriate mode collate_fn = MaxTokensPerRankCollator( @@ -404,6 +412,7 @@ def get_data_loader( dummy_sample=dummy_sample, flash_enabled=flash_enabled, pad_token_id=pad_token_id, + use_hpu_packer = use_hpu_packer, ) return DataLoader( From 9fccd07b2606477d2a51d46537c8532a922b30e9 Mon Sep 17 00:00:00 2001 From: Sergey Plotnikov Date: Tue, 28 Oct 2025 08:52:24 -0700 Subject: [PATCH 5/5] Add torch.compile workaround Signed-off-by: Sergey Plotnikov --- .../training/batch_loss_manager.py | 33 ++++++++++++++++++- src/instructlab/training/main_ds.py | 3 +- src/instructlab/training/model.py | 2 ++ 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/src/instructlab/training/batch_loss_manager.py b/src/instructlab/training/batch_loss_manager.py index e0aa14df..26e5cc83 100644 --- a/src/instructlab/training/batch_loss_manager.py +++ b/src/instructlab/training/batch_loss_manager.py @@ -46,7 +46,7 @@ class BatchLossManager: - Computing average losses for logging """ - def __init__(self, model, accelerator, world_size: int, local_rank: int, device: str): + def __init__(self, model, accelerator, world_size: int, local_rank: int, device: str, save_grads: bool): """ Initialize the BatchLossManager. 
@@ -64,6 +64,8 @@ def __init__(self, model, accelerator, world_size: int, local_rank: int, device: self.torch_device = torch.device("hpu") else: self.torch_device = torch.device("cuda", local_rank) + self.save_grads = save_grads + self.grad_buffer = {} def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]: """ @@ -87,6 +89,7 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] grad_accum_steps = 0 # process each minibatch + self.grad_buffer = {} for mb in batch: # extract minibatch-specific info micro_batch_size = mb["num_samples"] @@ -105,12 +108,21 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] ) self.accelerator.backward(scaled_loss) + # save gradients + if self.save_grads and len(batch) > 1: + self._copy_grads_to_buffer() + self._zero_model_grads() + # accumulate losses grad_accum_steps += 1 accumulated_loss += raw_losses.main_loss if raw_losses.aux_loss is not None: accumulated_aux_loss += raw_losses.aux_loss + # restore gradients + if self.grad_buffer: + self._restore_grads_from_buffer() + # reduce metrics across ranks batch_total_samples, batch_total_length = self._reduce_metrics( batch_total_samples, batch_total_length @@ -189,3 +201,22 @@ def _compute_average_loss( ).item() return avg_loss_across_ranks + + def _copy_grads_to_buffer(self): + for p in self.model.parameters(): + if p.grad is None: + continue + if p not in self.grad_buffer: + self.grad_buffer[p] = p.grad.detach().clone() + else: + self.grad_buffer[p] += p.grad.detach() + + def _restore_grads_from_buffer(self): + for p in self.model.parameters(): + if p in self.grad_buffer: + p.grad = self.grad_buffer[p] + + def _zero_model_grads(self): + for p in self.model.parameters(): + if p.grad is not None: + p.grad = None diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index e47ed225..95a973ff 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -111,7 +111,7 @@ def train( global_grad_norm = None # Initialize the batch loss manager - batch_loss_manager = BatchLossManager(model, accelerator, world_size, local_rank, args.device) + batch_loss_manager = BatchLossManager(model, accelerator, world_size, local_rank, args.device, args.device=="hpu" and args.torch_compile) # Blast through batches for epoch in range(args.current_epoch, args.num_epochs): @@ -346,6 +346,7 @@ def main(args): noise_alpha=args.NEFTune_alpha, lora_quant_bits=args.lora_quant_bits, device=args.device, + torch_compile=args.torch_compile, ) args.base_model_args = m.base_model_args diff --git a/src/instructlab/training/model.py b/src/instructlab/training/model.py index 97540596..44017afa 100644 --- a/src/instructlab/training/model.py +++ b/src/instructlab/training/model.py @@ -531,6 +531,7 @@ def __init__( lora_config: Optional[LoraConfig] = None, lora_quant_bits: int = 0, device: Optional[str] = None, + torch_compile: bool = False, ): super().__init__( model_path=model_path, @@ -541,6 +542,7 @@ def __init__( lora_config=lora_config, lora_quant_bits=lora_quant_bits, device=device, + torch_compile=torch_compile, ) self.model = AutoModelForCausalLM.from_pretrained(**self.base_model_args) self._post_model_init()
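A standalone sketch of the shape-bucketing rule introduced in patch 2 (compute_bucket_size), with a few worked values. The function body is copied verbatim from the diff above; the sample lengths are illustrative only and not taken from any real dataset.

    import math

    def compute_bucket_size(length: int) -> int:
        # Round the padded length up to the next multiple of 2^(msb_pos - 4),
        # so the graph compiler only ever sees a limited set of input shapes.
        msb_pos = length.bit_length()
        alignment = (1 << (msb_pos - 4)) if msb_pos >= 4 else 1
        return math.ceil(length / alignment) * alignment

    for length in (7, 130, 900, 1000):        # illustrative lengths only
        print(length, "->", compute_bucket_size(length))
    # 7    -> 7     (msb_pos < 4, no rounding)
    # 130  -> 144   (alignment 16)
    # 900  -> 960   (alignment 64)
    # 1000 -> 1024  (alignment 64)

In the patched padded_mb_collate_fn, max_len is passed through this function when the HPU packer is selected, so each minibatch pads to one of a small set of bucketed lengths rather than to every distinct maximum, which limits graph recompilations on HPU.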