src/instructlab/training/batch_loss_manager.py (1 addition & 1 deletion)
@@ -174,7 +174,7 @@ def _compute_average_loss(
         total_batch_loss = (
             accumulated_loss * self.world_size / batch_num_loss_counted_tokens
         )
-        if self.model.is_gpt_oss and accumulated_aux_loss is not None:
+        if accumulated_aux_loss is not None:
             total_batch_loss += accumulated_aux_loss

         # reduce across ranks
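For context, a minimal sketch of what the changed branch in `_compute_average_loss` now computes; the values below are made-up placeholders standing in for the manager's accumulated state:

```python
import torch

# Placeholder values standing in for the manager's accumulated state.
world_size = 4
batch_num_loss_counted_tokens = 1024
accumulated_loss = torch.tensor(3.2)       # summed main loss over the micro-batches
accumulated_aux_loss = torch.tensor(0.05)  # summed MoE aux loss, or None for dense models

total_batch_loss = accumulated_loss * world_size / batch_num_loss_counted_tokens
if accumulated_aux_loss is not None:       # no longer gated on is_gpt_oss
    total_batch_loss += accumulated_aux_loss
# ...followed by the reduce across ranks shown in the file above.
```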
src/instructlab/training/gpt_oss_utils_correct.py (13 additions & 1 deletion)
@@ -398,6 +398,15 @@ def is_gpt_oss(model_path_or_config: str | PretrainedConfig) -> bool:
"""
Determine if we should convert GPT-OSS format during saving.
"""
return is_known_model(model_path_or_config, "gpt_oss")


def is_known_model(
model_path_or_config: str | PretrainedConfig, known_model_type: str | list[str]
) -> bool:
"""
Determine if the model is a known model.
"""
if not isinstance(model_path_or_config, (PretrainedConfig, str)):
raise ValueError(
f"cannot detect model: received invalid argument of type {type(model_path_or_config)}"
@@ -408,7 +417,10 @@ def is_gpt_oss(model_path_or_config: str | PretrainedConfig) -> bool:
     if isinstance(model_path_or_config, str):
         model_config = AutoConfig.from_pretrained(model_path_or_config)

-    return getattr(model_config, "model_type", None) == "gpt_oss"
+    known_model_types = (
+        [known_model_type] if isinstance(known_model_type, str) else known_model_type
+    )
+    return getattr(model_config, "model_type", None) in known_model_types


def add_gpt_oss_quantization_config(config):
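A minimal usage sketch of the new helper; the config below is built by hand purely for illustration (in the trainer it comes from the checkpoint path), and passing a config object relies on the signature's `str | PretrainedConfig` contract:

```python
from transformers import PretrainedConfig

from instructlab.training.gpt_oss_utils_correct import is_gpt_oss, is_known_model

# Hand-built config standing in for AutoConfig.from_pretrained(<model path>).
cfg = PretrainedConfig()
cfg.model_type = "granitemoehybrid"

print(is_known_model(cfg, "granitemoehybrid"))               # True: single type as a string
print(is_known_model(cfg, ["gpt_oss", "granitemoehybrid"]))  # True: list of candidate types
print(is_gpt_oss(cfg))                                       # False: now a thin wrapper over is_known_model
```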
src/instructlab/training/main_ds.py (7 additions & 6 deletions)
@@ -346,12 +346,13 @@ def main(args):
     # GPT-OSS specifically
     # We don't want to use use_orig_params for GPT-OSS models
     fsdp_should_use_orig_params = False
-    if m.is_gpt_oss:
-        logger.info("🎯 Detected GPT-OSS model - freezing router parameters")
-        freeze_router_params(m)
-        # For GPT-OSS, we need to use the original parameters so we can properly
-        # freeze the router parameters.
-        fsdp_should_use_orig_params = True
+    if m.is_gpt_oss or m.is_granitemoehybrid:
+        frozen_router_params = freeze_router_params(m)
+        if frozen_router_params:
+            logger.info("🎯 Detected an MoE model - frozen router parameters")
+            # For an MoE model, we need to use the original parameters so we can properly
+            # freeze the router parameters.
+            fsdp_should_use_orig_params = True

     # Mini_trainer approach: simplified setup
     # No complex calculations needed - the data loader handles everything
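For background, a rough sketch of the FSDP constraint this flag works around; raw FSDP is shown for illustration (the trainer routes the flag through its own FSDP setup), and the toy module is invented for the example:

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Toy stand-in for an MoE block: a frozen "router" next to a trainable expert.
block = nn.ModuleDict({"router": nn.Linear(8, 2), "expert": nn.Linear(8, 8)})
for param in block["router"].parameters():
    param.requires_grad = False

# With use_orig_params=False, FSDP flattens parameters and requires uniform
# requires_grad within each flat group, so mixing frozen routers with trainable
# experts fails. use_orig_params=True permits the mix, which is why the flag is
# set only after freeze_router_params() reports that something was frozen.
# (Running this requires an initialized process group, e.g. under torchrun.)
sharded = FSDP(block, use_orig_params=True)
```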
src/instructlab/training/model.py (4 additions & 3 deletions)
@@ -43,7 +43,7 @@
     DistributedBackend,
     Optimizer,
 )
-from instructlab.training.gpt_oss_utils_correct import is_gpt_oss
+from instructlab.training.gpt_oss_utils_correct import is_gpt_oss, is_known_model
 from instructlab.training.type_definitions import ModelInputs, ModelLosses


@@ -65,6 +65,7 @@ def __init__(
         quant_config = None

         # check model type & set on the class
+        self.is_granitemoehybrid = is_known_model(model_path, "granitemoehybrid")
         self.is_gpt_oss = is_gpt_oss(model_path)
         if self.is_gpt_oss:
             # Third Party
@@ -418,7 +419,7 @@ def compute_loss(

         # add the MoE auxiliary loss (currently we only support this for GPT-OSS)
         if (
-            self.is_gpt_oss
+            (self.is_gpt_oss or self.is_granitemoehybrid)
             and hasattr(output, "aux_loss")
             and output.aux_loss is not None
         ):
@@ -429,7 +430,7 @@
         scaled_main_loss = primary_loss * world_size / samples_in_batch

         # For GPT-OSS: add unscaled auxiliary loss after scaling main loss
-        if self.is_gpt_oss and aux_loss is not None:
+        if aux_loss is not None:
             scaled_main_loss += aux_loss

         raw_losses = ModelLosses(main_loss=primary_loss, aux_loss=aux_loss)
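For reference, a small sketch of the guard in `compute_loss` above; `MoEOutput` is a toy stand-in for a Hugging Face model output, with field names chosen to mirror the real ones:

```python
from dataclasses import dataclass

import torch


@dataclass
class MoEOutput:
    loss: torch.Tensor
    aux_loss: torch.Tensor | None = None


output = MoEOutput(loss=torch.tensor(2.7), aux_loss=torch.tensor(0.03))

# Mirrors the updated guard: any model flagged as MoE that reports an aux_loss
# takes this path, whether it is GPT-OSS or a Granite MoE hybrid.
aux_loss = None
if hasattr(output, "aux_loss") and output.aux_loss is not None:
    aux_loss = output.aux_loss
```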
src/instructlab/training/utils.py (7 additions & 4 deletions)
@@ -902,13 +902,13 @@ def load_latest_full_state(args, accelerator) -> None:

 def freeze_router_params(model: Model):
     """
-    Freeze router parameters for GPT-OSS models before FSDP setup.
+    Freeze router parameters for MoE models before FSDP setup.

     Args:
         model: The model to check and potentially freeze parameters

     Returns:
-        bool: True if this is a GPT-OSS model and parameters were frozen
+        bool: True if this is an MoE model and parameters were frozen
     """

     # Freeze router parameters BEFORE accelerator setup
@@ -919,8 +919,11 @@ def freeze_router_params(model: Model):
             frozen_count += 1
             logger.info(f"❄️ Frozen router parameter: {name}")

-    logger.info(f"✅ Frozen {frozen_count} router parameters for GPT-OSS model")
-    return True
+    if frozen_count > 0:
+        logger.info(f"✅ Frozen {frozen_count} router parameters for an MoE model")
+        return True
+    else:
+        return False


def test_model_inference_quick(model, tokenizer, stage_name):
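A short sketch of the pattern `freeze_router_params` now follows, with the name filter assumed to match on "router" (the actual predicate sits in the collapsed lines above):

```python
import torch.nn as nn


def freeze_router_params_sketch(model: nn.Module) -> bool:
    """Illustrative only: freeze anything router-like and report whether any were found."""
    frozen_count = 0
    for name, param in model.named_parameters():
        if "router" in name:  # assumed filter
            param.requires_grad = False
            frozen_count += 1
    # The return value now distinguishes "MoE model, routers frozen" (True) from
    # "dense model, nothing to freeze" (False) instead of always returning True.
    return frozen_count > 0


moe = nn.ModuleDict({"router": nn.Linear(4, 2), "expert": nn.Linear(4, 4)})
dense = nn.Linear(4, 4)
assert freeze_router_params_sketch(moe) is True
assert freeze_router_params_sketch(dense) is False
```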