From 161ae8061d6508d9e47c49e9f72172fb951c444e Mon Sep 17 00:00:00 2001
From: Mikio Takeuchi
Date: Tue, 18 Nov 2025 17:13:33 +0900
Subject: [PATCH 1/5] Support granite 4 models as MoE models

---
 src/instructlab/training/batch_loss_manager.py |  2 +-
 .../training/gpt_oss_utils_correct.py          | 18 ++++++++++++++++++
 src/instructlab/training/main_ds.py            | 13 +++++++------
 src/instructlab/training/model.py              |  7 ++++---
 src/instructlab/training/utils.py              | 11 +++++++----
 5 files changed, 37 insertions(+), 14 deletions(-)

diff --git a/src/instructlab/training/batch_loss_manager.py b/src/instructlab/training/batch_loss_manager.py
index f0e10a89..cc6da021 100644
--- a/src/instructlab/training/batch_loss_manager.py
+++ b/src/instructlab/training/batch_loss_manager.py
@@ -174,7 +174,7 @@ def _compute_average_loss(
         total_batch_loss = (
             accumulated_loss * self.world_size / batch_num_loss_counted_tokens
         )
-        if self.model.is_gpt_oss and accumulated_aux_loss is not None:
+        if accumulated_aux_loss is not None:
             total_batch_loss += accumulated_aux_loss
 
         # reduce across ranks
diff --git a/src/instructlab/training/gpt_oss_utils_correct.py b/src/instructlab/training/gpt_oss_utils_correct.py
index 430a890f..e4e7c35d 100644
--- a/src/instructlab/training/gpt_oss_utils_correct.py
+++ b/src/instructlab/training/gpt_oss_utils_correct.py
@@ -411,6 +411,24 @@ def is_gpt_oss(model_path_or_config: str | PretrainedConfig) -> bool:
     return getattr(model_config, "model_type", None) == "gpt_oss"
 
 
+def is_known_model(model_path_or_config: str | PretrainedConfig, known_model_type: str | list[str]) -> bool:
+    """
+    Determine if the model is a known model.
+    """
+    if not isinstance(model_path_or_config, (PretrainedConfig, str)):
+        raise ValueError(
+            f"cannot detect model: received invalid argument of type {type(model_path_or_config)}"
+        )
+
+    # convert to config
+    model_config = model_path_or_config
+    if isinstance(model_path_or_config, str):
+        model_config = AutoConfig.from_pretrained(model_path_or_config)
+
+    known_model_types = [known_model_type] if isinstance(known_model_type, str) else known_model_type
+    return getattr(model_config, "model_type", None) in known_model_types
+
+
 def add_gpt_oss_quantization_config(config):
     """
     Add GPT-OSS quantization configuration to a model config object.
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index a0931353..0c4740f5 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -346,12 +346,13 @@ def main(args):
     # GPT-OSS specifically
     # We don't want to use use_orig_params for GPT-OSS models
     fsdp_should_use_orig_params = False
-    if m.is_gpt_oss:
-        logger.info("🎯 Detected GPT-OSS model - freezing router parameters")
-        freeze_router_params(m)
-        # For GPT-OSS, we need to use the original parameters so we can properly
-        # freeze the router parameters.
-        fsdp_should_use_orig_params = True
+    if m.is_gpt_oss or m.is_granitemoehybrid:  # NOTE is this guard needed?
+        frozen_router_params = freeze_router_params(m)
+        if frozen_router_params:
+            logger.info("🎯 Detected an MoE model - frozen router parameters")
+            # For an MoE model, we need to use the original parameters so we can properly
+            # freeze the router parameters.
+            fsdp_should_use_orig_params = True
 
     # Mini_trainer approach: simplified setup
     # No complex calculations needed - the data loader handles everything
diff --git a/src/instructlab/training/model.py b/src/instructlab/training/model.py
index de863e1d..75681cb0 100644
--- a/src/instructlab/training/model.py
+++ b/src/instructlab/training/model.py
@@ -43,7 +43,7 @@
     DistributedBackend,
     Optimizer,
 )
-from instructlab.training.gpt_oss_utils_correct import is_gpt_oss
+from instructlab.training.gpt_oss_utils_correct import is_gpt_oss, is_known_model
 from instructlab.training.type_definitions import ModelInputs, ModelLosses
 
 
@@ -65,6 +65,7 @@ def __init__(
         quant_config = None
 
         # check model type & set on the mclasss
+        self.is_granitemoehybrid = is_known_model(model_path, "granitemoehybrid")
         self.is_gpt_oss = is_gpt_oss(model_path)
         if self.is_gpt_oss:
             # Third Party
@@ -418,7 +419,7 @@ def compute_loss(
 
         # add the MoE auxiliary loss (currently we only support this for GPT-OSS)
         if (
-            self.is_gpt_oss
+            (self.is_gpt_oss or self.is_granitemoehybrid)  # NOTE is this guard needed?
             and hasattr(output, "aux_loss")
             and output.aux_loss is not None
         ):
@@ -429,7 +430,7 @@ def compute_loss(
         scaled_main_loss = primary_loss * world_size / samples_in_batch
 
         # For GPT-OSS: add unscaled auxiliary loss after scaling main loss
-        if self.is_gpt_oss and aux_loss is not None:
+        if aux_loss is not None:
             scaled_main_loss += aux_loss
 
         raw_losses = ModelLosses(main_loss=primary_loss, aux_loss=aux_loss)
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index 275a4b7e..fc31858e 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -902,13 +902,13 @@ def load_latest_full_state(args, accelerator) -> None:
 
 def freeze_router_params(model: Model):
     """
-    Freeze router parameters for GPT-OSS models before FSDP setup.
+    Freeze router parameters for MoE models before FSDP setup.
 
     Args:
         model: The model to check and potentially freeze parameters
 
     Returns:
-        bool: True if this is a GPT-OSS model and parameters were frozen
+        bool: True if this is an MoE model and parameters were frozen
     """
     # Freeze router parameters BEFORE accelerator setup
@@ -919,8 +919,11 @@ def freeze_router_params(model: Model):
             frozen_count += 1
             logger.info(f"❄️ Frozen router parameter: {name}")
 
-    logger.info(f"✅ Frozen {frozen_count} router parameters for GPT-OSS model")
-    return True
+    if frozen_count > 0:
+        logger.info(f"✅ Frozen {frozen_count} router parameters for an MoE model")
+        return True
+    else:
+        return False
 
 
 def test_model_inference_quick(model, tokenizer, stage_name):

From e8f89223d72dc56e6d21ef9154a98ed6e0e7a038 Mon Sep 17 00:00:00 2001
From: Mikio Takeuchi
Date: Tue, 18 Nov 2025 22:15:21 +0900
Subject: [PATCH 2/5] fix ruff errors

---
 src/instructlab/training/gpt_oss_utils_correct.py | 8 ++++++--
 src/instructlab/training/main_ds.py               | 2 +-
 src/instructlab/training/model.py                 | 2 +-
 3 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/instructlab/training/gpt_oss_utils_correct.py b/src/instructlab/training/gpt_oss_utils_correct.py
index e4e7c35d..7cf93300 100644
--- a/src/instructlab/training/gpt_oss_utils_correct.py
+++ b/src/instructlab/training/gpt_oss_utils_correct.py
@@ -411,7 +411,9 @@ def is_gpt_oss(model_path_or_config: str | PretrainedConfig) -> bool:
     return getattr(model_config, "model_type", None) == "gpt_oss"
 
 
-def is_known_model(model_path_or_config: str | PretrainedConfig, known_model_type: str | list[str]) -> bool:
+def is_known_model(
+    model_path_or_config: str | PretrainedConfig, known_model_type: str | list[str]
+) -> bool:
     """
     Determine if the model is a known model.
     """
@@ -425,7 +427,9 @@ def is_known_model(model_path_or_config: str | PretrainedConfig, known_model_typ
         model_config = AutoConfig.from_pretrained(model_path_or_config)
 
-    known_model_types = [known_model_type] if isinstance(known_model_type, str) else known_model_type
+    known_model_types = (
+        [known_model_type] if isinstance(known_model_type, str) else known_model_type
+    )
     return getattr(model_config, "model_type", None) in known_model_types
 
 
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index 0c4740f5..b73a86f3 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -346,7 +346,7 @@ def main(args):
     # GPT-OSS specifically
     # We don't want to use use_orig_params for GPT-OSS models
     fsdp_should_use_orig_params = False
-    if m.is_gpt_oss or m.is_granitemoehybrid:  # NOTE is this guard needed?
+    if m.is_gpt_oss or m.is_granitemoehybrid:
         frozen_router_params = freeze_router_params(m)
         if frozen_router_params:
             logger.info("🎯 Detected an MoE model - frozen router parameters")
diff --git a/src/instructlab/training/model.py b/src/instructlab/training/model.py
index 75681cb0..bb89a1c4 100644
--- a/src/instructlab/training/model.py
+++ b/src/instructlab/training/model.py
@@ -419,7 +419,7 @@ def compute_loss(
 
         # add the MoE auxiliary loss (currently we only support this for GPT-OSS)
         if (
-            (self.is_gpt_oss or self.is_granitemoehybrid)  # NOTE is this guard needed?
+            (self.is_gpt_oss or self.is_granitemoehybrid)
             and hasattr(output, "aux_loss")
             and output.aux_loss is not None
         ):
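
For reviewers, a minimal usage sketch of the is_known_model helper introduced above. The checkpoint name below is an illustrative assumption, not something pinned by this series:

    # Usage sketch for is_known_model (illustrative; the checkpoint name is a
    # hypothetical example, not part of this series).
    from instructlab.training.gpt_oss_utils_correct import is_known_model

    model_path = "ibm-granite/granite-4.0-tiny-preview"  # hypothetical Granite 4 MoE checkpoint

    # Granite 4 hybrid MoE models report model_type == "granitemoehybrid" in config.json.
    if is_known_model(model_path, "granitemoehybrid"):
        print("Granite 4 MoE model detected")

    # The helper also accepts a list of model types, so both MoE families
    # handled by this series can be checked in a single call.
    if is_known_model(model_path, ["gpt_oss", "granitemoehybrid"]):
        print("MoE model: router parameters will be frozen before FSDP setup")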
From 34b7cd603f9d848e6f7635e6bb21c29a91b62af0 Mon Sep 17 00:00:00 2001
From: Mikio Takeuchi
Date: Wed, 19 Nov 2025 19:19:07 +0900
Subject: [PATCH 5/5] address bot's comment

---
 src/instructlab/training/gpt_oss_utils_correct.py | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)

diff --git a/src/instructlab/training/gpt_oss_utils_correct.py b/src/instructlab/training/gpt_oss_utils_correct.py
index 7cf93300..a77ee15a 100644
--- a/src/instructlab/training/gpt_oss_utils_correct.py
+++ b/src/instructlab/training/gpt_oss_utils_correct.py
@@ -398,17 +398,7 @@ def is_gpt_oss(model_path_or_config: str | PretrainedConfig) -> bool:
     """
     Determine if we should convert GPT-OSS format during saving.
     """
-    if not isinstance(model_path_or_config, (PretrainedConfig, str)):
-        raise ValueError(
-            f"cannot detect model: received invalid argument of type {type(model_path_or_config)}"
-        )
-
-    # convert to config
-    model_config = model_path_or_config
-    if isinstance(model_path_or_config, str):
-        model_config = AutoConfig.from_pretrained(model_path_or_config)
-
-    return getattr(model_config, "model_type", None) == "gpt_oss"
+    return is_known_model(model_path_or_config, "gpt_oss")
 
 
 def is_known_model(
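
Taken together, the series makes router freezing generic: freeze_router_params now reports whether it actually froze anything, and use_orig_params is only forced on when it did. A condensed, self-contained sketch of that rule, assuming (as the "Frozen router parameter" log lines suggest) that router weights are matched by parameter name:

    # Condensed sketch of the post-series behavior (assumption: router weights
    # are identified by "router" appearing in the parameter name; the real
    # matching rule lives outside the hunks shown in this series).
    import torch.nn as nn

    def freeze_router_params_sketch(model: nn.Module) -> bool:
        """Freeze MoE router parameters; return True only if any were frozen."""
        frozen_count = 0
        for name, param in model.named_parameters():
            if "router" in name and param.requires_grad:
                param.requires_grad = False  # exclude from gradient updates
                frozen_count += 1
        return frozen_count > 0

    # main_ds.main then only opts into FSDP's use_orig_params when freezing
    # actually happened:
    #     fsdp_should_use_orig_params = freeze_router_params_sketch(m)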