2 changes: 1 addition & 1 deletion src/instructlab/training/batch_loss_manager.py
@@ -174,7 +174,7 @@ def _compute_average_loss(
         total_batch_loss = (
             accumulated_loss * self.world_size / batch_num_loss_counted_tokens
         )
-        if self.model.is_gpt_oss and accumulated_aux_loss is not None:
+        if accumulated_aux_loss is not None:
             total_batch_loss += accumulated_aux_loss

         # reduce across ranks
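The hunk above drops the GPT-OSS-only gate: any accumulated auxiliary loss is now folded into the batch loss. A minimal standalone sketch of the arithmetic, with invented values (the real `world_size` and token count come from the surrounding class):

    # Standalone sketch of the averaging logic above; values are illustrative.
    accumulated_loss = 2.0                 # summed main loss on this rank
    world_size = 4                         # number of distributed ranks
    batch_num_loss_counted_tokens = 1024   # loss-counted tokens in the batch
    accumulated_aux_loss = 0.05            # summed MoE aux loss, or None for dense models

    total_batch_loss = accumulated_loss * world_size / batch_num_loss_counted_tokens
    if accumulated_aux_loss is not None:   # no model-type check needed any more
        total_batch_loss += accumulated_aux_loss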
22 changes: 22 additions & 0 deletions src/instructlab/training/gpt_oss_utils_correct.py
@@ -411,6 +411,28 @@ def is_gpt_oss(model_path_or_config: str | PretrainedConfig) -> bool:
     return getattr(model_config, "model_type", None) == "gpt_oss"


+def is_known_model(
+    model_path_or_config: str | PretrainedConfig, known_model_type: str | list[str]
+) -> bool:
+    """
+    Determine if the model is a known model.
+    """
+    if not isinstance(model_path_or_config, (PretrainedConfig, str)):
+        raise ValueError(
+            f"cannot detect model: received invalid argument of type {type(model_path_or_config)}"
+        )
+
+    # convert to config
+    model_config = model_path_or_config
+    if isinstance(model_path_or_config, str):
+        model_config = AutoConfig.from_pretrained(model_path_or_config)
+
+    known_model_types = (
+        [known_model_type] if isinstance(known_model_type, str) else known_model_type
+    )
+    return getattr(model_config, "model_type", None) in known_model_types
+
+
 def add_gpt_oss_quantization_config(config):
     """
     Add GPT-OSS quantization configuration to a model config object.
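A brief usage sketch for the new helper; the calls below are illustrative and not taken from the PR:

    # Hypothetical usage; assumes is_known_model from gpt_oss_utils_correct is in scope.
    from transformers import PretrainedConfig

    cfg = PretrainedConfig()
    cfg.model_type = "granitemoehybrid"

    assert is_known_model(cfg, "granitemoehybrid")               # single type
    assert is_known_model(cfg, ["gpt_oss", "granitemoehybrid"])  # list of types

    # A path or HF model id also works; the config is then loaded via
    # AutoConfig.from_pretrained (requires the model files or network access):
    # is_known_model("/path/to/model", "granitemoehybrid")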
13 changes: 7 additions & 6 deletions src/instructlab/training/main_ds.py
@@ -346,12 +346,13 @@ def main(args):
     # GPT-OSS specifically
     # We don't want to use use_orig_params for GPT-OSS models
     fsdp_should_use_orig_params = False
-    if m.is_gpt_oss:
-        logger.info("🎯 Detected GPT-OSS model - freezing router parameters")
-        freeze_router_params(m)
-        # For GPT-OSS, we need to use the original parameters so we can properly
-        # freeze the router parameters.
-        fsdp_should_use_orig_params = True
+    if m.is_gpt_oss or m.is_granitemoehybrid:
+        frozen_router_params = freeze_router_params(m)
+        if frozen_router_params:
+            logger.info("🎯 Detected an MoE model - frozen router parameters")
+            # For an MoE model, we need to use the original parameters so we can properly
+            # freeze the router parameters.
+            fsdp_should_use_orig_params = True

     # Mini_trainer approach: simplified setup
     # No complex calculations needed - the data loader handles everything
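The `use_orig_params=True` requirement reflects how FSDP flattens parameters: frozen and trainable tensors may only share a flattened group when the original parameters are preserved. A minimal sketch of that interaction, assuming `torch.distributed` is initialized under a launcher (the toy module is illustrative):

    # Sketch only: assumes a distributed process group is already initialized.
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    model = nn.Linear(8, 8)             # stand-in for the MoE model
    model.bias.requires_grad = False    # stand-in for a frozen router parameter

    # With use_orig_params=False, FSDP requires uniform requires_grad within
    # each flattened parameter; True lifts that restriction.
    wrapped = FSDP(model, use_orig_params=True)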
7 changes: 4 additions & 3 deletions src/instructlab/training/model.py
@@ -43,7 +43,7 @@
     DistributedBackend,
     Optimizer,
 )
-from instructlab.training.gpt_oss_utils_correct import is_gpt_oss
+from instructlab.training.gpt_oss_utils_correct import is_gpt_oss, is_known_model
 from instructlab.training.type_definitions import ModelInputs, ModelLosses


@@ -65,6 +65,7 @@ def __init__(
         quant_config = None

         # check model type & set on the mclasss
+        self.is_granitemoehybrid = is_known_model(model_path, "granitemoehybrid")
         self.is_gpt_oss = is_gpt_oss(model_path)
         if self.is_gpt_oss:
             # Third Party
@@ -418,7 +419,7 @@ def compute_loss(

         # add the MoE auxiliary loss (currently we only support this for GPT-OSS)
         if (
-            self.is_gpt_oss
+            (self.is_gpt_oss or self.is_granitemoehybrid)
             and hasattr(output, "aux_loss")
             and output.aux_loss is not None
         ):
@@ -429,7 +430,7 @@
         scaled_main_loss = primary_loss * world_size / samples_in_batch

         # For GPT-OSS: add unscaled auxiliary loss after scaling main loss
-        if self.is_gpt_oss and aux_loss is not None:
+        if aux_loss is not None:
             scaled_main_loss += aux_loss

         raw_losses = ModelLosses(main_loss=primary_loss, aux_loss=aux_loss)
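A numeric sketch of the two hunks above, with invented values: the main loss is rescaled for the cross-rank average, while the auxiliary loss is added unscaled afterwards.

    # Illustrative numbers only; mirrors the order of operations above.
    primary_loss = 1.2        # mean cross-entropy from the forward pass
    aux_loss = 0.03           # MoE load-balancing loss; None for dense models
    world_size = 8
    samples_in_batch = 64

    scaled_main_loss = primary_loss * world_size / samples_in_batch   # 0.15
    if aux_loss is not None:
        scaled_main_loss += aux_loss                                  # 0.18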
11 changes: 7 additions & 4 deletions src/instructlab/training/utils.py
@@ -902,13 +902,13 @@ def load_latest_full_state(args, accelerator) -> None:

 def freeze_router_params(model: Model):
     """
-    Freeze router parameters for GPT-OSS models before FSDP setup.
+    Freeze router parameters for MoE models before FSDP setup.

     Args:
         model: The model to check and potentially freeze parameters

     Returns:
-        bool: True if this is a GPT-OSS model and parameters were frozen
+        bool: True if this is an MoE model and parameters were frozen
     """

     # Freeze router parameters BEFORE accelerator setup
@@ -919,8 +919,11 @@ def freeze_router_params(model: Model):
             frozen_count += 1
             logger.info(f"❄️ Frozen router parameter: {name}")

-    logger.info(f"✅ Frozen {frozen_count} router parameters for GPT-OSS model")
-    return True
+    if frozen_count > 0:
+        logger.info(f"✅ Frozen {frozen_count} router parameters for an MoE model")
+        return True
+    else:
+        return False


 def test_model_inference_quick(model, tokenizer, stage_name):
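For orientation, a hedged sketch of `freeze_router_params` as it reads after this change; the diff view elides the loop header, so the name-matching rule below is an assumption rather than the verbatim source:

    # Sketch, not verbatim: the matching condition is assumed.
    def freeze_router_params(model):
        frozen_count = 0
        for name, param in model.named_parameters():
            if "router" in name:              # assumed matching rule
                param.requires_grad = False
                frozen_count += 1
        return frozen_count > 0               # True only if something was frozen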