diff --git a/src/instructlab/training/accelerator.py b/src/instructlab/training/accelerator.py
index 4baa7c0e..afc6fc88 100644
--- a/src/instructlab/training/accelerator.py
+++ b/src/instructlab/training/accelerator.py
@@ -43,6 +43,7 @@ def __init__(
         deepspeed_cpu_offload_optimizer_ratio: Optional[float] = None,
         fsdp_cpu_offload_params: Optional[bool] = False,
         fsdp_use_orig_params: Optional[bool] = False,
+        device: Optional[str] = None,
     ):
         self.samples_per_gpu = samples_per_gpu
         self.save_samples = save_samples
@@ -61,6 +62,7 @@ def __init__(
         self.fsdp_cpu_offload_params = fsdp_cpu_offload_params
         self.fsdp_use_orig_params = fsdp_use_orig_params
         self.lr_scheduler = None
+        self.device_str = device  # must be set before its first use in self.get_fsdp_config()
         if self.distributed_framework == DistributedBackend.DEEPSPEED:
             # Standard
             accel_args = {
@@ -81,6 +83,10 @@ def __init__(
                 "fsdp_plugin": self.get_fsdp_config(),
                 "mixed_precision": "bf16",
             }
+        if device == "hpu":
+            from optimum.habana.accelerate import GaudiAccelerator as TransformersAccel
+        else:
+            from accelerate import Accelerator as TransformersAccel
         self.accelerator = TransformersAccel(
             **accel_args,
         )
@@ -159,7 +165,9 @@ def get_fsdp_config(self):
             use_orig_params=self.fsdp_use_orig_params,
             # TODO(osilkin): expose switch for fp32 reduction
         )
-
+        if self.device_str == "hpu":
+            fsdp_plugin.use_orig_params = True
+            fsdp_plugin.sync_module_states = True
         return fsdp_plugin

     def get_ds_plugin(
diff --git a/src/instructlab/training/batch_loss_manager.py b/src/instructlab/training/batch_loss_manager.py
index f0e10a89..26e5cc83 100644
--- a/src/instructlab/training/batch_loss_manager.py
+++ b/src/instructlab/training/batch_loss_manager.py
@@ -46,7 +46,7 @@ class BatchLossManager:
     - Computing average losses for logging
     """

-    def __init__(self, model, accelerator, world_size: int, local_rank: int):
+    def __init__(self, model, accelerator, world_size: int, local_rank: int, device: str, save_grads: bool):
         """
         Initialize the BatchLossManager.

@@ -60,7 +60,12 @@ def __init__(self, model, accelerator, world_size: int, local_rank: int):
         self.accelerator: Accelerator = accelerator
         self.world_size: int = world_size
         self.local_rank: int = local_rank
-        self.torch_device = torch.device("cuda", local_rank)
+        if device == "hpu":
+            self.torch_device = torch.device("hpu")
+        else:
+            self.torch_device = torch.device("cuda", local_rank)
+        self.save_grads = save_grads
+        self.grad_buffer = {}

     def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]:
         """
@@ -84,6 +89,7 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
         grad_accum_steps = 0

         # process each minibatch
+        self.grad_buffer = {}
         for mb in batch:
             # extract minibatch-specific info
             micro_batch_size = mb["num_samples"]
@@ -102,12 +108,21 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
             )
             self.accelerator.backward(scaled_loss)

+            # save gradients
+            if self.save_grads and len(batch) > 1:
+                self._copy_grads_to_buffer()
+                self._zero_model_grads()
+
             # accumulate losses
             grad_accum_steps += 1
             accumulated_loss += raw_losses.main_loss
             if raw_losses.aux_loss is not None:
                 accumulated_aux_loss += raw_losses.aux_loss

+        # restore gradients
+        if self.grad_buffer:
+            self._restore_grads_from_buffer()
+
         # reduce metrics across ranks
         batch_total_samples, batch_total_length = self._reduce_metrics(
             batch_total_samples, batch_total_length
@@ -186,3 +201,22 @@ def _compute_average_loss(
         ).item()

         return avg_loss_across_ranks
+
+    def _copy_grads_to_buffer(self):
+        for p in self.model.parameters():
+            if p.grad is None:
+                continue
+            if p not in self.grad_buffer:
+                self.grad_buffer[p] = p.grad.detach().clone()
+            else:
+                self.grad_buffer[p] += p.grad.detach()
+
+    def _restore_grads_from_buffer(self):
+        for p in self.model.parameters():
+            if p in self.grad_buffer:
+                p.grad = self.grad_buffer[p]
+
+    def _zero_model_grads(self):
+        for p in self.model.parameters():
+            if p.grad is not None:
+                p.grad = None
diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py
index b561a8d6..ff03b49b 100644
--- a/src/instructlab/training/config.py
+++ b/src/instructlab/training/config.py
@@ -254,3 +254,8 @@ class TrainingArgs(BaseModel):
     log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(
         default="INFO"
     )
+
+    device: Optional[str] = None
+    torch_compile: bool = False
+    num_chunks: int = 1
+
\ No newline at end of file
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index dc8a171e..95a973ff 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -111,7 +111,7 @@ def train(
     global_grad_norm = None

     # Initialize the batch loss manager
-    batch_loss_manager = BatchLossManager(model, accelerator, world_size, local_rank)
+    batch_loss_manager = BatchLossManager(
+        model, accelerator, world_size, local_rank, args.device,
+        save_grads=(args.device == "hpu" and args.torch_compile),
+    )

     # Blast through batches
     for epoch in range(args.current_epoch, args.num_epochs):
@@ -150,8 +150,12 @@ def train(
             elapsed_time = time.time() - start
             overall_throughput = batch_metrics.total_samples / elapsed_time
             current_lr = accelerator.lr_scheduler.get_last_lr()[0]
-            cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
-            cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+            if args.device == "hpu":
+                mem_allocated = torch.hpu.memory_allocated() / (1024**3)
+                malloc_retries = 0
+            else:
+                mem_allocated = torch.cuda.memory_allocated() / (1024**3)
+                malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
             global_grad_norm = (
                 model.get_global_grad_norm()
                 if hasattr(model, "get_global_grad_norm")
@@ -173,8 +177,8 @@ def train(
                     "rank": dist.get_rank(),
                     "overall_throughput": overall_throughput,
                     "lr": current_lr,
-                    "cuda_mem_allocated": cuda_mem_allocated,
-                    "cuda_malloc_retries": cuda_malloc_retries,
+                    ("hpu" if args.device == "hpu" else "cuda") + "_mem_allocated": mem_allocated,
+                    ("hpu" if args.device == "hpu" else "cuda") + "_malloc_retries": malloc_retries,
                     "num_loss_counted_tokens": batch_metrics.num_loss_counted_tokens,
                     "num_tokens_rank0": batch_metrics.total_length,
                     "batch_size": batch_metrics.total_samples,
@@ -206,7 +210,8 @@ def train(
             global_step += 1
             if local_rank == 0:
                 inner_pb.update(1)
-            torch.cuda.empty_cache()
+            if args.device != "hpu":
+                torch.cuda.empty_cache()
         if args.checkpoint_at_epoch:
             base_logger.debug(f"Saving checkpoint at epoch {epoch}")
             save_checkpoint(
@@ -284,17 +289,22 @@ def main(args):
     args.model_type = model_conf.model_type

     #### distributed init #####
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    if args.device == "hpu":
+        torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
+    else:
+        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
     args.local_rank = int(os.environ["LOCAL_RANK"])

     timeout = _get_collective_timeout()
-    if timeout is not None:
-        dist.init_process_group(timeout=timeout)
-    else:
-        dist.init_process_group()
+    backend = "hccl" if args.device == "hpu" else None
+    dist.init_process_group(backend=backend, timeout=timeout)
+
     args.global_rank = dist.get_rank()
-    tensor = torch.ByteTensor([False]).cuda()
+    if args.device == "hpu":
+        tensor = torch.ByteTensor([False]).to("hpu")
+    else:
+        tensor = torch.ByteTensor([False]).cuda()
     dist.all_reduce(tensor)
     dist.barrier()
@@ -335,6 +345,8 @@ def main(args):
         flash_enabled=flash_enabled,
         noise_alpha=args.NEFTune_alpha,
         lora_quant_bits=args.lora_quant_bits,
+        device=args.device,
+        torch_compile=args.torch_compile,
     )

     args.base_model_args = m.base_model_args
@@ -372,6 +384,8 @@ def main(args):
         num_workers=8,  # I don't like this but am setting it for consistency
         flash_enabled=flash_enabled,
         pad_token_id=pad_token_id,
+        num_chunks=args.num_chunks,
+        use_hpu_packer=(args.device == "hpu"),
     )

     if args.local_rank == 0:
@@ -410,6 +424,7 @@ def main(args):
         fsdp_cpu_offload_params=args.cpu_offload_params_fsdp,
         save_samples=args.save_samples,
         fsdp_use_orig_params=fsdp_should_use_orig_params,
+        device=args.device,
     )
     # optimizer needs model that has been prepared by accelerator
     # and then accelerator needs to be prepared AGAIN once optimizer is initialized
@@ -588,6 +603,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     if train_args.keep_last_checkpoint_only:
         command.append("--keep_last_checkpoint_only")

+    if train_args.device:
+        command.append(f"--device={train_args.device}")
+    if train_args.torch_compile:
+        command.append("--torch-compile")
+    command.append(f"--num-chunks={train_args.num_chunks}")
+
     logger.info("Running training command as subprocess: %s", " ".join(command))
     process = None
     interrupt: KeyboardInterrupt | Exception | None = None
@@ -789,8 +810,33 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
         action="store_true",
         help="Use Liger kernels for training.",
     )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default=None,
+        help="PyTorch device to use.",
+    )
+    parser.add_argument(
+        "--torch-compile",
+        action="store_true",
+        default=False,
+        help="Enable torch.compile (HPU only).",
+    )
+    parser.add_argument(
+        "--num-chunks",
+        type=int,
+        default=1,
+        help="Number of chunks to split the dataset into for sequential training.",
+    )
+
     args = parser.parse_args()
-    set_random_seed(args.seed)
+
+    if args.device == "hpu":
+        # Imported for their side effects: these register the HPU device and the
+        # HCCL distributed backend with PyTorch.
+        import habana_frameworks.torch.core as htcore
+        import habana_frameworks.torch.distributed.hccl
+        from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+        adapt_transformers_to_gaudi()
+
+    set_random_seed(args.seed, args.device)
     main(args)
 """
diff --git a/src/instructlab/training/model.py b/src/instructlab/training/model.py
index de863e1d..44017afa 100644
--- a/src/instructlab/training/model.py
+++ b/src/instructlab/training/model.py
@@ -30,7 +30,6 @@
 # Third Party
 from peft import LoraConfig
 from torch.optim import AdamW
-from transformers import Mxfp4Config  # pylint: disable=no-name-in-module
 from transformers import (
     AutoModelForCausalLM,
     BitsAndBytesConfig,
@@ -57,17 +56,22 @@ def __init__(
         flash_enabled: bool = False,
         lora_config: Optional[LoraConfig] = None,
         lora_quant_bits: int = 0,
+        device: Optional[str] = None,
+        torch_compile: bool = False,
     ):
         self.lora_config = lora_config
         self.noise_alpha = noise_alpha
         self.tokenizer = tokenizer
         self.distributed_framework = distributed_framework
+        self.device = device
+        self.torch_compile = torch_compile
         quant_config = None

         # check model type & set on the mclasss
         self.is_gpt_oss = is_gpt_oss(model_path)
         if self.is_gpt_oss:
             # Third Party
+            from transformers import Mxfp4Config  # pylint: disable=no-name-in-module
             quant_config = Mxfp4Config(dequantize=True)

         # TODO: Add support for 8bit quantization
@@ -102,6 +106,19 @@ def __init__(

     def _post_model_init(self):
         """Common initialization steps that should happen after model initialization."""
+
+        if self.device == "hpu" and self.torch_compile:
+            cache_size_limit = 10 * 1000
+            torch._dynamo.config.cache_size_limit = cache_size_limit
+            torch._dynamo.config.accumulated_cache_size_limit = 2 * cache_size_limit
+            self.model = torch.compile(self.model, backend="hpu_backend", dynamic=False)
+            for layer in self.model.model.layers:
+                layer.compile(backend="hpu_backend", dynamic=False)
+            if os.environ.get("RANK", "0") == "0":
+                logger.info("torch.compile has been enabled")
+
         self.reconcile_tokenizer()
         if self.lora_config:
             self.model = self.prepare_peft_model()
@@ -270,7 +287,11 @@ def _is_causal_lm_model(self) -> bool:
             bool: True if the model is a causal language model, False otherwise.
""" # Third Party - return "ForCausalLM" in self.model.__class__.__name__ + if self.device != "hpu": + class_name = self.model.__class__.__name__ + else: + class_name = self.model._orig_mod.__class__.__name__ if self.model.__class__.__name__ == 'OptimizedModule' else self.model.__class__.__name__ + return "ForCausalLM" in class_name def reconcile_tokenizer(self): if len(self.tokenizer) > self.model.config.vocab_size: @@ -326,6 +347,17 @@ def reconcile_tokenizer(self): ): self.model.config.eos_token_id = self.tokenizer.eos_token_id + if self.device == "hpu": + model = self.model._orig_mod if self.model.__class__.__name__ == 'OptimizedModule' else self.model + class_name = model.__class__.__name__ + + replace_no_split_modules = { + 'GaudiLlamaForCausalLM': ['GaudiLlamaDecoderLayer',] + } + + if class_name in replace_no_split_modules: + model._no_split_modules = replace_no_split_modules[class_name] + if not self._is_causal_lm_model(): raise ValueError( f"Model must be a causal language model, got {type(self.model)}" @@ -386,9 +418,17 @@ def compute_loss( - Dataclass containing the raw pre-scaled losses """ # Forward pass to get logits + hpu_args = {} + if self.device == "hpu": + hpu_args = { + "use_flash_attention":True, + "lazy_mode":False, + } + output = self( **inputs, use_cache=False, + **hpu_args, ) # Manual loss computation with reduction="none" following mini_trainer's exact approach @@ -490,6 +530,8 @@ def __init__( flash_enabled: bool = False, lora_config: Optional[LoraConfig] = None, lora_quant_bits: int = 0, + device: Optional[str] = None, + torch_compile: bool = False, ): super().__init__( model_path=model_path, @@ -499,6 +541,8 @@ def __init__( flash_enabled=flash_enabled, lora_config=lora_config, lora_quant_bits=lora_quant_bits, + device=device, + torch_compile=torch_compile, ) self.model = AutoModelForCausalLM.from_pretrained(**self.base_model_args) self._post_model_init() diff --git a/src/instructlab/training/padded_batch_packer.py b/src/instructlab/training/padded_batch_packer.py index 3cad913b..8f5837e3 100644 --- a/src/instructlab/training/padded_batch_packer.py +++ b/src/instructlab/training/padded_batch_packer.py @@ -57,7 +57,7 @@ def _find_optimal_batch_size( for size in range(2, max_sequences + 1): # Check if this batch size fits within token limit padded_tokens = _compute_padded_tokens(lengths, start_idx, start_idx + size) - if padded_tokens > max_tokens: + if padded_tokens > max_tokens: break # Compute padding ratio @@ -286,3 +286,222 @@ def compute_padding_stats(batch_lengths: list[int], batches: list[list[int]]) -> "padding_ratio": padding_ratio, "num_batches": len([b for b in batches if b and b[0] != -1]), } + + +import math +def compute_bucket_size(length: int) -> int: + """ + Bucketing algorithm based on the most significant bit of the sample length. + Finds the most significant bit position 'S', then divides the range + [2^S, 2^(S+1)] into 16 equal buckets of size 2^(S-4). + Limits overhead to at most 1/16 while reducing number of graph recompilations. 
+    """
+    msb_pos = length.bit_length()
+    alignment = (1 << (msb_pos - 4)) if msb_pos >= 4 else 1
+    return math.ceil(length / alignment) * alignment
+
+
+def _batch_cost(lengths: list[int]) -> float:
+    if not (isinstance(lengths, list) and len(lengths) > 0):
+        raise TypeError("lengths must be a non-empty list")
+    return lengths[0] * len(lengths)
+
+
+def _batch_size_in_tokens(lengths: list[int]) -> float:
+    if not (isinstance(lengths, list) and len(lengths) > 0):
+        raise TypeError("lengths must be a non-empty list")
+    return lengths[0] * len(lengths)
+
+
+def _check_batch_cost(
+    sorted_lengths: list[int],
+    num_ranks: int,
+    max_cost: float,
+) -> list[int]:
+    if not (isinstance(sorted_lengths, list) and len(sorted_lengths) > 0):
+        raise TypeError("sorted_lengths must be a non-empty list")
+
+    bins = [[] for _ in range(num_ranks)]
+    current_bin = 0
+
+    for sample_length in sorted_lengths:
+        while True:
+            # try to add to the current bin
+            bins[current_bin].append(sample_length)
+            cost = _batch_cost(bins[current_bin])
+
+            # go to the next sample if the current one fits
+            if cost < max_cost:
+                break
+
+            # bin overflow, move the last sample to the next bin if possible
+            if len(bins[current_bin]) == 1:
+                break
+
+            if current_bin >= num_ranks - 1:
+                return None
+
+            bins[current_bin].pop()
+            current_bin += 1
+
+    bin_sizes = [len(bin) for bin in bins]
+    return bin_sizes
+
+
+def _distribute_samples_across_ranks(
+    sorted_lengths: list[int],
+    max_tokens: int,
+    num_ranks: int,
+    rank: int,
+) -> list[int]:
+    # compute the cost range, from 0 to the maximum possible cost when all samples land in one bin
+    lower_bound = 0
+    upper_bound = _batch_cost(sorted_lengths)
+
+    # binary-search for an optimal distribution based on batch cost
+    prev_bin_sizes = None
+    epsilon = 1.0  # cost has the same order of magnitude as the batch token count
+    while upper_bound - lower_bound > epsilon:
+        mid = (lower_bound + upper_bound) / 2
+        cur_bin_sizes = _check_batch_cost(sorted_lengths, num_ranks, mid)
+
+        if cur_bin_sizes is None:
+            lower_bound = mid + epsilon
+        else:
+            upper_bound = mid
+            prev_bin_sizes = cur_bin_sizes
+
+    # sanity check
+    if prev_bin_sizes is not None:
+        if len(prev_bin_sizes) != num_ranks or sum(prev_bin_sizes) != len(sorted_lengths):
+            raise ValueError("Something went wrong, we lost samples during distribution across ranks")
+
+        if any(size == 0 for size in prev_bin_sizes):
+            if any(size > 1 for size in prev_bin_sizes):
+                raise ValueError("Something went wrong, we put more than one sample per rank for a small batch")
+
+    return prev_bin_sizes
+
+
+def _check_batch_size_in_tokens(
+    sorted_lengths: list[int],
+    max_tokens: int,
+    minibatch_sizes: list[int],
+) -> bool:
+    first_sample_idx = 0
+    for bs in minibatch_sizes:
+        if bs > 0:
+            minibatch = sorted_lengths[first_sample_idx:first_sample_idx + bs]
+            if _batch_size_in_tokens(minibatch) >= max_tokens:
+                return False
+        first_sample_idx += bs
+    return True
+
+
+def _compute_sample_indices_for_current_rank(
+    lengths: list[int],
+    minibatches: list[list[int]],
+    rank: int,
+) -> list[list[int]]:
+    sorted_indices = np.argsort(lengths)[::-1]
+    minibatch_indices = []
+    first_sample_idx = 0
+    for minibatch in minibatches:
+        first_sample_idx += sum(minibatch[:rank])
+        minibatch_indices.append(
+            sorted_indices[first_sample_idx:first_sample_idx + minibatch[rank]].tolist()
+        )
+        first_sample_idx += sum(minibatch[rank:])
+    return minibatch_indices
+
+
+def _compute_num_samples_in_grad_accum_step(
+    num_samples: int,
+    grad_accum: int,
+) -> list[int]:
+    if grad_accum <= 0:
+        return []
+    if grad_accum == 1:
+        return [num_samples]
+
+    step_size = num_samples // grad_accum
+    remainder = num_samples % grad_accum
+    result = [step_size] * grad_accum
+    result[-1] += remainder
+    return result
+
+
+def _batch_packing_core_hpu(
+    lengths: list[int],
+    max_tokens: int,
+    num_ranks: int,
+    rank: int,
+) -> list[list[int]]:
+    # try different gradient accumulation values
+    for grad_accum in [1, 2, 4]:
+        # break the input batch into several gradient accumulation steps
+        grad_accum_step_sizes = _compute_num_samples_in_grad_accum_step(len(lengths), grad_accum)
+
+        first_sample_idx = 0
+        minibatches = []
+        for step_size in grad_accum_step_sizes:
+            step_lengths = lengths[first_sample_idx:first_sample_idx + step_size]
+            first_sample_idx += step_size
+            sorted_lengths = sorted(step_lengths, reverse=True)
+
+            # find the optimal sample distribution for a single step based on computation cost
+            minibatch_sizes = _distribute_samples_across_ranks(sorted_lengths, max_tokens, num_ranks, rank)
+            if minibatch_sizes is None:
+                raise ValueError("Something went wrong, sample distribution across ranks failed")
+
+            # check whether the found distribution fits within the token limit
+            if not _check_batch_size_in_tokens(sorted_lengths, max_tokens, minibatch_sizes):
+                # does not fit, increase the number of gradient accumulation steps
+                break
+            minibatches.append(minibatch_sizes)
+
+        # check whether we found a suitable sample distribution
+        if len(minibatches) == grad_accum:
+            break
+
+    # sanity check
+    if not (
+        len(minibatches) == grad_accum
+        and all(len(minibatch) == num_ranks for minibatch in minibatches)
+        and sum(sum(minibatch) for minibatch in minibatches) == len(lengths)
+    ):
+        raise ValueError("Could not distribute samples across ranks")
+
+    # compute indices for the current rank
+    minibatch_indices = _compute_sample_indices_for_current_rank(lengths, minibatches, rank)
+
+    # sanity check: no index may appear twice
+    from itertools import chain
+
+    all_indices = list(chain.from_iterable(minibatch_indices))
+    if len(all_indices) != len(set(all_indices)):
+        raise ValueError("Something went wrong, duplicated indices in the list")
+
+    # add one dummy sample to each empty minibatch
+    for minibatch in minibatch_indices:
+        if len(minibatch) == 0:
+            minibatch.append(-1)
+
+    return minibatch_indices
+
+
+def batch_lengths_to_minibatches_hpu(
+    batch_lengths: list[int],
+    max_tokens_per_rank: int,
+    num_ranks: int,
+    rank: int,
+) -> list[list[int]]:
+    if not batch_lengths:
+        return []
+    result = _batch_packing_core_hpu(batch_lengths, max_tokens_per_rank, num_ranks, rank)
+    return result
diff --git a/src/instructlab/training/sampler.py b/src/instructlab/training/sampler.py
index be2b3f4b..766e9c27 100644
--- a/src/instructlab/training/sampler.py
+++ b/src/instructlab/training/sampler.py
@@ -14,6 +14,8 @@
 from instructlab.training.batch_packer import batch_lengths_to_minibatches_lpt
 from instructlab.training.padded_batch_packer import (
     batch_lengths_to_minibatches_padded,
+    batch_lengths_to_minibatches_hpu,
+    compute_bucket_size,
 )
 from instructlab.training.type_definitions import CollatedItem

@@ -24,10 +26,12 @@ class EpochSampler(Sampler):
     Replaces the naive distributed sampler with reproducible epoch-based shuffling.
""" - def __init__(self, len_data: int, seed: int = 67, epoch: int = 0): + def __init__(self, len_data: int, seed: int = 67, epoch: int = 0, lengths: list[int] = None, num_chunks: int = 1): self.len_data = len_data self.seed = seed self._epoch = epoch + self.lengths = lengths + self.num_chunks = num_chunks @property def epoch(self) -> int: @@ -37,9 +41,32 @@ def set_epoch(self, epoch: int): self._epoch = epoch def generate_samples(self): - g = torch.Generator() - g.manual_seed(self.seed + self._epoch) - samples = torch.randperm(self.len_data, generator=g).tolist() + if self.num_chunks == 1 : + g = torch.Generator() + g.manual_seed(self.seed + self._epoch) + samples = torch.randperm(self.len_data, generator=g).tolist() + else: + sorted_indices = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i], reverse=True) + + # break list of indices into several chunks + chunk_size = len(sorted_indices) // self.num_chunks + chunks = [] + for i in range(self.num_chunks): + if i == self.num_chunks - 1: + chunks.append(sorted_indices[i * chunk_size:]) + else: + chunks.append(sorted_indices[i * chunk_size:(i + 1) * chunk_size]) + + import random + random.seed(self.seed + self._epoch) + random.shuffle(chunks) + for chunk in chunks: + random.shuffle(chunk) + + # flatten the list + from itertools import chain + samples = list(chain.from_iterable(chunks)) + return samples def __iter__(self): @@ -113,7 +140,7 @@ def mb_collate_fn(minibatch, batch_num_loss_counted_tokens) -> CollatedItem: def padded_mb_collate_fn( - minibatch, batch_num_loss_counted_tokens, pad_token_id=0 + minibatch, batch_num_loss_counted_tokens, pad_token_id=0, use_hpu_packer: bool = False, ) -> CollatedItem: """Collates a list of samples into a padded batch for standard attention. @@ -148,6 +175,8 @@ def padded_mb_collate_fn( # Find max length in this batch max_len = max(len(item["input_ids"]) for item in minibatch) + if use_hpu_packer: + max_len = compute_bucket_size(max_len) # Prepare lists for batched tensors padded_input_ids = [] @@ -221,6 +250,7 @@ def __init__( dummy_sample=None, flash_enabled: bool = True, pad_token_id: int = 0, + use_hpu_packer: bool = False, ): self.max_tokens_per_rank = max_tokens_per_rank self.flash_enabled = flash_enabled @@ -247,11 +277,14 @@ def __init__( self.batch_packer = batch_lengths_to_minibatches_lpt self.collate_fn = mb_collate_fn else: - self.batch_packer = batch_lengths_to_minibatches_padded + if not use_hpu_packer: + self.batch_packer = batch_lengths_to_minibatches_padded + else: + self.batch_packer = batch_lengths_to_minibatches_hpu # Create a wrapper for padded collate that includes pad_token_id self.collate_fn = ( lambda minibatch, batch_num_loss_counted_tokens: padded_mb_collate_fn( - minibatch, batch_num_loss_counted_tokens, pad_token_id + minibatch, batch_num_loss_counted_tokens, pad_token_id, use_hpu_packer ) ) @@ -346,6 +379,8 @@ def get_data_loader( num_workers: int = 0, flash_enabled: bool = True, pad_token_id: int = 0, + num_chunks: int = 1, + use_hpu_packer: bool = False, ): """Create a data loader with epoch-based sampling and batch packing. 
@@ -360,12 +395,14 @@ def get_data_loader(
         num_workers: Number of data loading workers
         flash_enabled: Whether flash attention is enabled (affects collation strategy)
         pad_token_id: Token ID to use for padding (only used when flash_enabled=False)
+        num_chunks: Number of chunks to split the dataset into for sequential training
+        use_hpu_packer: Whether to use the HPU-specific batch packer

     Returns:
         DataLoader configured with appropriate collator based on flash_enabled
     """
     dataset = TokenDataset(data_path)
-    sampler = EpochSampler(len(dataset), seed=seed)
+    sampler = EpochSampler(len(dataset), seed=seed, lengths=dataset.get_lengths(), num_chunks=num_chunks)

     # Create unified collator with appropriate mode
     collate_fn = MaxTokensPerRankCollator(
@@ -375,6 +412,7 @@ def get_data_loader(
         dummy_sample=dummy_sample,
         flash_enabled=flash_enabled,
         pad_token_id=pad_token_id,
+        use_hpu_packer=use_hpu_packer,
     )

     return DataLoader(
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index 275a4b7e..6ffad25a 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -318,6 +318,7 @@ def reduce_sum_forward(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
+        **(_deprecated_arguments if model.device == "hpu" else {}),
     )

     return_dict = isinstance(output, dict)
@@ -775,13 +776,16 @@ def _get_state_dict_patched(model, unwrap=False):

     accelerator.get_state_dict = get_state_dict_unpatched


-def set_random_seed(seed):
+def set_random_seed(seed, device: str):
     if seed is not None:
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
-        torch.cuda.manual_seed_all(seed)
-
+        if device == "hpu":
+            torch.hpu.manual_seed_all(seed)
+        else:
+            torch.cuda.manual_seed_all(seed)
+

 # TODO: move this to also live in the `Model` object
 def save_checkpoint(
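
Usage sketch (not part of the patch): assuming the patched instructlab.training package is installed, the HPU batch packer and bucketing helper added above can be exercised on their own. The sequence lengths and the per-rank token budget below are made-up values for illustration.

# Illustrative only: exercises the helpers added in padded_batch_packer.py.
from instructlab.training.padded_batch_packer import (
    batch_lengths_to_minibatches_hpu,
    compute_bucket_size,
)

batch_lengths = [1800, 1500, 900, 700, 650, 300, 200, 120]  # tokens per sample
max_tokens_per_rank = 4096
num_ranks = 2

for rank in range(num_ranks):
    # One inner list per gradient-accumulation step; entries are indices into
    # batch_lengths, and -1 marks a dummy sample for an otherwise empty step.
    minibatches = batch_lengths_to_minibatches_hpu(
        batch_lengths, max_tokens_per_rank, num_ranks, rank
    )
    for step, indices in enumerate(minibatches):
        lengths = [batch_lengths[i] for i in indices if i >= 0]
        padded_len = compute_bucket_size(max(lengths)) if lengths else 0
        print(f"rank {rank} step {step}: indices={indices} padded_len={padded_len}")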