10 changes: 9 additions & 1 deletion src/instructlab/training/accelerator.py
@@ -43,6 +43,7 @@ def __init__(
deepspeed_cpu_offload_optimizer_ratio: Optional[float] = None,
fsdp_cpu_offload_params: Optional[bool] = False,
fsdp_use_orig_params: Optional[bool] = False,
device: Optional[str] = None,
Instead of just str, you might want to create a new type that constrains the choices between cuda and hpu:

AcceleratorDevice = Literal["cuda", "hpu"]

# ...

# safe to infer cuda, since this will be the most common option
device: Optional[AcceleratorDevice] = "cuda"  
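
A self-contained version of that suggestion might look like the sketch below (the AcceleratorDevice alias and the "cuda" default come from the comment above; the helper name and runtime check are illustrative):

# Sketch only: constrain the accepted device strings at the type level.
from typing import Literal, Optional, get_args

AcceleratorDevice = Literal["cuda", "hpu"]

def resolve_device(device: Optional[AcceleratorDevice]) -> AcceleratorDevice:
    # Default to cuda, since it is the most common option.
    resolved = device or "cuda"
    if resolved not in get_args(AcceleratorDevice):
        raise ValueError(f"unsupported device {resolved!r}; expected one of {get_args(AcceleratorDevice)}")
    return resolved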

):
self.samples_per_gpu = samples_per_gpu
self.save_samples = save_samples
@@ -61,6 +62,7 @@ def __init__(
self.fsdp_cpu_offload_params = fsdp_cpu_offload_params
self.fsdp_use_orig_params = fsdp_use_orig_params
self.lr_scheduler = None
self.device_str = device  # must be set before its first use in self.get_fsdp_config()
if self.distributed_framework == DistributedBackend.DEEPSPEED:
# Standard
accel_args = {
@@ -81,6 +83,10 @@ def __init__(
"fsdp_plugin": self.get_fsdp_config(),
"mixed_precision": "bf16",
}
if device == "hpu":
from optimum.habana.accelerate import GaudiAccelerator as TransformersAccel
else:
from accelerate import Accelerator as TransformersAccel
self.accelerator = TransformersAccel(
**accel_args,
)
@@ -159,7 +165,9 @@ def get_fsdp_config(self):
use_orig_params=self.fsdp_use_orig_params,
# TODO(osilkin): expose switch for fp32 reduction
)

if self.device_str == "hpu":
fsdp_plugin.use_orig_params = True
fsdp_plugin.sync_module_states = True
return fsdp_plugin

def get_ds_plugin(
38 changes: 36 additions & 2 deletions src/instructlab/training/batch_loss_manager.py
@@ -46,7 +46,7 @@ class BatchLossManager:
- Computing average losses for logging
"""

def __init__(self, model, accelerator, world_size: int, local_rank: int):
def __init__(self, model, accelerator, world_size: int, local_rank: int, device: str, save_grads: bool):
"""
Initialize the BatchLossManager.

@@ -60,7 +60,12 @@ def __init__(self, model, accelerator, world_size: int, local_rank: int):
self.accelerator: Accelerator = accelerator
self.world_size: int = world_size
self.local_rank: int = local_rank
self.torch_device = torch.device("cuda", local_rank)
if device == "hpu":
self.torch_device = torch.device("hpu")
else:
self.torch_device = torch.device("cuda", local_rank)
self.save_grads = save_grads
self.grad_buffer = {}

def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]:
"""
@@ -84,6 +89,7 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
grad_accum_steps = 0

# process each minibatch
self.grad_buffer = {}
for mb in batch:
# extract minibatch-specific info
micro_batch_size = mb["num_samples"]
@@ -102,12 +108,21 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
)
self.accelerator.backward(scaled_loss)

# save gradients
if self.save_grads and len(batch) > 1:
self._copy_grads_to_buffer()
self._zero_model_grads()
@miseshkiwrk How come we need to move the gradients back and forth between the buffer? Does HPU not allow accumulating grads directly in the current objects?

If this is the case, could you please leave a block comment somewhere describing the issue and why this is needed? Thanks in advance!


# accumulate losses
grad_accum_steps += 1
accumulated_loss += raw_losses.main_loss
if raw_losses.aux_loss is not None:
accumulated_aux_loss += raw_losses.aux_loss

# restore gradients
if self.grad_buffer:
self._restore_grads_from_buffer()

# reduce metrics across ranks
batch_total_samples, batch_total_length = self._reduce_metrics(
batch_total_samples, batch_total_length
@@ -186,3 +201,22 @@ def _compute_average_loss(
).item()

return avg_loss_across_ranks

def _copy_grads_to_buffer(self):
for p in self.model.parameters():
if p.grad is None:
continue
if p not in self.grad_buffer:
self.grad_buffer[p] = p.grad.detach().clone()
else:
self.grad_buffer[p] += p.grad.detach()

def _restore_grads_from_buffer(self):
for p in self.model.parameters():
if p in self.grad_buffer:
p.grad = self.grad_buffer[p]

def _zero_model_grads(self):
for p in self.model.parameters():
if p.grad is not None:
p.grad = None
5 changes: 5 additions & 0 deletions src/instructlab/training/config.py
@@ -254,3 +254,8 @@ class TrainingArgs(BaseModel):
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(
default="INFO"
)

device: Optional[str] = None
torch_compile: bool = False
num_chunks: int = 1

Comment on lines +258 to +261
I think these are good to expose, but what are your thoughts on adjusting the field names like this?

  • device --> device_type (a constrained string of cuda | hpu, defaulting to cuda), to make it clearer that this describes the underlying hardware rather than a specific device we train on
  • num_chunks --> padded_packing_num_sorting_bins: this field represents a very specific implementation detail, so a longer name is fitting. A description should be added explaining that it helps training efficiency when packing has to be used, at the cost of randomness/increased bias.

Would it also make sense to have it be enable_torch_compile or use_torch_compile?
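
A minimal sketch of how those renamed fields could look on TrainingArgs (the field names are the proposals above; the defaults and descriptions are assumptions):

# Sketch only: proposed field names with a constrained device type.
from typing import Literal

from pydantic import BaseModel, Field

class TrainingArgsSketch(BaseModel):
    # Hardware family to train on, not a specific torch.device such as cuda:0.
    device_type: Literal["cuda", "hpu"] = "cuda"
    # Whether to wrap the model with torch.compile before training.
    use_torch_compile: bool = False
    padded_packing_num_sorting_bins: int = Field(
        default=1,
        description=(
            "Number of sorting bins used when padded packing is required; "
            "improves training efficiency at the cost of randomness/increased bias."
        ),
    )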

72 changes: 59 additions & 13 deletions src/instructlab/training/main_ds.py
@@ -111,7 +111,7 @@ def train(
global_grad_norm = None

# Initialize the batch loss manager
batch_loss_manager = BatchLossManager(model, accelerator, world_size, local_rank)
batch_loss_manager = BatchLossManager(model, accelerator, world_size, local_rank, args.device, args.device=="hpu" and args.torch_compile)
@splotnikv Since you're setting the save_grads kwarg conditionally and inline, would you mind specifying the current args by their kwargs explicitly? This way it'll be easier to manage and maintain in the future.
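
For reference, the explicit-kwargs form of that call could look like this (a sketch; parameter names are taken from the BatchLossManager constructor above):

# Sketch: same call with every argument passed by keyword.
batch_loss_manager = BatchLossManager(
    model=model,
    accelerator=accelerator,
    world_size=world_size,
    local_rank=local_rank,
    device=args.device,
    save_grads=(args.device == "hpu" and args.torch_compile),
)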


# Blast through batches
for epoch in range(args.current_epoch, args.num_epochs):
@@ -150,8 +150,12 @@ def train(
elapsed_time = time.time() - start
overall_throughput = batch_metrics.total_samples / elapsed_time
current_lr = accelerator.lr_scheduler.get_last_lr()[0]
cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
if args.device == "hpu":
mem_allocated = torch.hpu.memory_allocated() / (1024**3)
malloc_retries = 0
else:
mem_allocated = torch.cuda.memory_allocated() / (1024**3)
malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
global_grad_norm = (
model.get_global_grad_norm()
if hasattr(model, "get_global_grad_norm")
@@ -173,8 +177,8 @@ def train(
"rank": dist.get_rank(),
"overall_throughput": overall_throughput,
"lr": current_lr,
"cuda_mem_allocated": cuda_mem_allocated,
"cuda_malloc_retries": cuda_malloc_retries,
("hpu" if args.device == "hpu" else "cuda") + "_mem_allocated": mem_allocated,
("hpu" if args.device == "hpu" else "cuda") + "_malloc_retries": malloc_retries,
"num_loss_counted_tokens": batch_metrics.num_loss_counted_tokens,
"num_tokens_rank0": batch_metrics.total_length,
"batch_size": batch_metrics.total_samples,
@@ -206,7 +210,8 @@ def train(
global_step += 1
if local_rank == 0:
inner_pb.update(1)
torch.cuda.empty_cache()
if args.device != "hpu":
torch.cuda.empty_cache()
if args.checkpoint_at_epoch:
base_logger.debug(f"Saving checkpoint at epoch {epoch}")
save_checkpoint(
@@ -284,17 +289,22 @@ def main(args):
args.model_type = model_conf.model_type

#### distributed init #####
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
if args.device == "hpu":
torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
else:
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
args.local_rank = int(os.environ["LOCAL_RANK"])

timeout = _get_collective_timeout()
if timeout is not None:
dist.init_process_group(timeout=timeout)
else:
dist.init_process_group()
backend = "hccl" if args.device == "hpu" else None
torch.distributed.init_process_group(backend=backend, timeout=timeout)


args.global_rank = dist.get_rank()
tensor = torch.ByteTensor([False]).cuda()
if args.device == "hpu":
tensor = torch.ByteTensor([False]).to('hpu')
else:
tensor = torch.ByteTensor([False]).cuda()
dist.all_reduce(tensor)
dist.barrier()

@@ -335,6 +345,8 @@ def main(args):
flash_enabled=flash_enabled,
noise_alpha=args.NEFTune_alpha,
lora_quant_bits=args.lora_quant_bits,
device=args.device,
torch_compile=args.torch_compile,
)

args.base_model_args = m.base_model_args
@@ -372,6 +384,8 @@ def main(args):
num_workers=8, # I don't like this but am setting it for consistency
flash_enabled=flash_enabled,
pad_token_id=pad_token_id,
num_chunks=args.num_chunks,
use_hpu_packer=(args.device=="hpu"),
)

if args.local_rank == 0:
@@ -410,6 +424,7 @@ def main(args):
fsdp_cpu_offload_params=args.cpu_offload_params_fsdp,
save_samples=args.save_samples,
fsdp_use_orig_params=fsdp_should_use_orig_params,
device=args.device,
)
# optimizer needs model that has been prepared by accelerator
# and then accelerator needs to be prepared AGAIN once optimizer is initialized
@@ -588,6 +603,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
if train_args.keep_last_checkpoint_only:
command.append("--keep_last_checkpoint_only")

command.append(f"--device={train_args.device}")
if train_args.torch_compile:
command.append("--torch-compile")
command.append(f"--num-chunks={train_args.num_chunks}")


logger.info("Running training command as subprocess: %s", " ".join(command))
process = None
interrupt: KeyboardInterrupt | Exception | None = None
@@ -789,8 +810,33 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
action="store_true",
help="Use Liger kernels for training.",
)
parser.add_argument(
"--device",
In the context of PyTorch, we typically assume device refers to a singular device on a node containing several accelerators, e.g.: cuda:0, cuda:1, etc.

Since you're looking to allow for switching between HPU and CUDA style training, maybe we could change this to be --device_type? This way, we can also constrain the options to be either cuda (by default) or hpu, and error out when given an option which doesn't fall in the set of allowed devices.

type=str,
default=None,
help="PyTorch device to use.",
)
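
Following up on the review comment above, a constrained flag could be defined roughly as follows (the name and choices adapt the reviewer's --device_type suggestion to the dash style used by the other flags; the help text is illustrative):

# Sketch: constrain the flag so an unknown device fails at argument-parsing time.
parser.add_argument(
    "--device-type",
    type=str,
    choices=["cuda", "hpu"],
    default="cuda",
    help="Hardware family to train on (cuda or hpu).",
)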
parser.add_argument(
'--torch-compile',
action='store_true', default=False,
help='Enable torch.compile (HPU only).'
)
parser.add_argument(
'--num-chunks',
type=int,
default=1,
help='Number of chunks to split dataset into for sequential training.'
)

args = parser.parse_args()
set_random_seed(args.seed)

if args.device == "hpu":
import habana_frameworks.torch.core as htcore
import habana_frameworks.torch.distributed.hccl
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
adapt_transformers_to_gaudi()

set_random_seed(args.seed, args.device)
main(args)

"""
48 changes: 46 additions & 2 deletions src/instructlab/training/model.py
@@ -30,7 +30,6 @@
# Third Party
from peft import LoraConfig
from torch.optim import AdamW
from transformers import Mxfp4Config # pylint: disable=no-name-in-module
from transformers import (
AutoModelForCausalLM,
BitsAndBytesConfig,
@@ -57,17 +56,22 @@ def __init__(
flash_enabled: bool = False,
lora_config: Optional[LoraConfig] = None,
lora_quant_bits: int = 0,
device: Optional[str] = None,
@splotnikv Is device here intended to be a specific device, or the broader device type (cuda/hpu)? If it's the latter, we should rename this to device_type to be consistent with the expected usage. Otherwise, it will be easy to get this confused with a specific torch.device instance such as cuda:0.

torch_compile: bool = False,
):
self.lora_config = lora_config
self.noise_alpha = noise_alpha
self.tokenizer = tokenizer
self.distributed_framework = distributed_framework
self.device = device
self.torch_compile = torch_compile
quant_config = None

# check model type & set on the mclasss
self.is_gpt_oss = is_gpt_oss(model_path)
if self.is_gpt_oss:
# Third Party
from transformers import Mxfp4Config # pylint: disable=no-name-in-module
@splotnikv How come this import is being moved from module-level to method-level?

quant_config = Mxfp4Config(dequantize=True)

# TODO: Add support for 8bit quantization
@@ -102,6 +106,19 @@ def __init__(

def _post_model_init(self):
"""Common initialization steps that should happen after model initialization."""

if self.device == "hpu" and self.torch_compile:
cache_size_limit = 10*1000
torch._dynamo.config.cache_size_limit = cache_size_limit
torch._dynamo.config.accumulated_cache_size_limit = 2*cache_size_limit
self.model = torch.compile(self.model, backend="hpu_backend", dynamic=False)
for layer in self.model.model.layers:
layer.compile(backend="hpu_backend", dynamic=False)
if os.environ.get("RANK", '0') == '0':
@splotnikv Are you trying to log from the main process on a given node here? The RANK env here corresponds to the rank of the node, not the process. If you just want to log a message on the main process (local rank 0), consider using the log_rank_0 function from instructlab.training.utils.
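
If the goal is simply a main-process log line, the suggestion above would look roughly like this (assuming log_rank_0 accepts a plain message string):

# Sketch: log only from the main process, per the review suggestion.
from instructlab.training.utils import log_rank_0

log_rank_0("torch.compile has been enabled")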

logger.info(
f"torch.compile has been enabled"
)

self.reconcile_tokenizer()
if self.lora_config:
self.model = self.prepare_peft_model()
@@ -270,7 +287,11 @@ def _is_causal_lm_model(self) -> bool:
bool: True if the model is a causal language model, False otherwise.
"""
# Third Party
return "ForCausalLM" in self.model.__class__.__name__
if self.device != "hpu":
class_name = self.model.__class__.__name__
else:
class_name = self.model._orig_mod.__class__.__name__ if self.model.__class__.__name__ == 'OptimizedModule' else self.model.__class__.__name__
return "ForCausalLM" in class_name

def reconcile_tokenizer(self):
if len(self.tokenizer) > self.model.config.vocab_size:
@@ -326,6 +347,17 @@ def reconcile_tokenizer(self):
):
self.model.config.eos_token_id = self.tokenizer.eos_token_id

if self.device == "hpu":
model = self.model._orig_mod if self.model.__class__.__name__ == 'OptimizedModule' else self.model
class_name = model.__class__.__name__

replace_no_split_modules = {
'GaudiLlamaForCausalLM': ['GaudiLlamaDecoderLayer',]
}

if class_name in replace_no_split_modules:
model._no_split_modules = replace_no_split_modules[class_name]

if not self._is_causal_lm_model():
raise ValueError(
f"Model must be a causal language model, got {type(self.model)}"
@@ -386,9 +418,17 @@ def compute_loss(
- Dataclass containing the raw pre-scaled losses
"""
# Forward pass to get logits
hpu_args = {}
if self.device == "hpu":
hpu_args = {
"use_flash_attention":True,
"lazy_mode":False,
}

output = self(
**inputs,
use_cache=False,
**hpu_args,
)

# Manual loss computation with reduction="none" following mini_trainer's exact approach
@@ -490,6 +530,8 @@ def __init__(
flash_enabled: bool = False,
lora_config: Optional[LoraConfig] = None,
lora_quant_bits: int = 0,
device: Optional[str] = None,
torch_compile: bool = False,
):
super().__init__(
model_path=model_path,
Expand All @@ -499,6 +541,8 @@ def __init__(
flash_enabled=flash_enabled,
lora_config=lora_config,
lora_quant_bits=lora_quant_bits,
device=device,
torch_compile=torch_compile,
)
self.model = AutoModelForCausalLM.from_pretrained(**self.base_model_args)
self._post_model_init()