From bdd431c7ffa99849ffc9c21c1f849d75f5a37855 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 00:26:51 -0700 Subject: [PATCH 01/28] support to expose agent config and seed hyperparameters --- src/cloudai/cli/handlers.py | 24 ++++++++- src/cloudai/models/scenario.py | 4 +- src/cloudai/models/workload.py | 95 +++++++++++++++++++++++++++++++++- 3 files changed, 120 insertions(+), 3 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index c642341e8..9efb4dfb5 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -131,7 +131,29 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace): ) continue - agent = agent_class(env) + agent_config = test_run.test.test_definition.agent_config + + if agent_config: + agent_kwargs = { + k: v for k, v in agent_config.model_dump().items() + if v is not None and k not in ['extra_params', 'seed_parameters'] + } + + if hasattr(agent_config, 'seed_parameters') and agent_config.seed_parameters: + action_space = env.define_action_space() + resolved_seeds = test_run.test.test_definition.resolve_seed_parameters(action_space) + if resolved_seeds: + agent_kwargs['seed_parameters'] = resolved_seeds + + agent_kwargs.update(agent_config.extra_params) + + try: + agent = agent_class(env, **agent_kwargs) + except TypeError as e: + logging.warning(f"Agent {agent_type} doesn't support some configuration parameters: {e}") + agent = agent_class(env) + else: + agent = agent_class(env) for step in range(agent.max_steps): result = agent.select_action() if result is None: diff --git a/src/cloudai/models/scenario.py b/src/cloudai/models/scenario.py index 9a568bf96..d7149e97c 100644 --- a/src/cloudai/models/scenario.py +++ b/src/cloudai/models/scenario.py @@ -22,7 +22,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_serializer, field_validator, model_validator from cloudai.core import CmdArgs, GitRepo, NsysConfiguration, Registry, Reporter, TestRun -from cloudai.models.workload import TestDefinition +from cloudai.models.workload import AgentConfig, BOAgentConfig, TestDefinition def parse_reports_spec( @@ -91,6 +91,7 @@ class TestRunModel(BaseModel): agent: Optional[str] = None agent_steps: Optional[int] = None agent_metrics: list[str] = Field(default=["default"]) + agent_config: Optional[AgentConfig] = None def tdef_model_dump(self) -> dict: """Return a dictionary with non-None values that correspond to the test definition fields.""" @@ -101,6 +102,7 @@ def tdef_model_dump(self) -> dict: "agent": self.agent, "agent_steps": self.agent_steps, "agent_metrics": self.agent_metrics, + "agent_config": self.agent_config.model_dump() if self.agent_config else None, "extra_container_mounts": self.extra_container_mounts, "extra_env_vars": self.extra_env_vars if self.extra_env_vars else None, "cmd_args": self.cmd_args.model_dump() if self.cmd_args else None, diff --git a/src/cloudai/models/workload.py b/src/cloudai/models/workload.py index 1745ae734..27eba1e62 100644 --- a/src/cloudai/models/workload.py +++ b/src/cloudai/models/workload.py @@ -18,11 +18,37 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator from cloudai.core import GitRepo, Installable, JobStatusResult, PythonExecutable, TestRun +class AgentConfig(BaseModel): + """Base configuration class for agents used in DSE.""" + + model_config = ConfigDict(extra="forbid") + 
+ # Common agent parameters + random_seed: Optional[int] = None + + # Allow for additional agent-specific parameters + extra_params: Dict[str, Any] = Field(default_factory=dict) + + +class BOAgentConfig(AgentConfig): + """Configuration for Bayesian Optimization agent.""" + + # BO-specific parameters + sobol_num_trials: Optional[int] = None + botorch_num_trials: Optional[int] = None + + # Seed parameters for starting optimization from known configuration + seed_parameters: Optional[Dict[str, Any]] = None + + # Allow for additional agent-specific parameters + extra_params: Dict[str, Any] = Field(default_factory=dict) + + class CmdArgs(BaseModel): """Test command arguments.""" @@ -107,6 +133,73 @@ class TestDefinition(BaseModel, ABC): agent_steps: int = 1 agent_metrics: list[str] = Field(default=["default"]) agent_reward_function: str = "inverse" + agent_config: Optional[AgentConfig] = None + + @field_validator('agent_config', mode='before') + @classmethod + def parse_agent_config(cls, v, info): + """Parse agent_config based on the agent type.""" + if v is None: + return None + + # If it's already an AgentConfig instance, return as-is + if isinstance(v, AgentConfig): + return v + + # If it's a dict, parse based on agent type + if isinstance(v, dict): + # Get the agent type from the model data + agent_type = info.data.get('agent', 'grid_search') + + # Map agent types to their config classes + agent_config_map = { + 'bo_gp': BOAgentConfig, + 'random_walker': RandomWalkerAgentConfig, + } + + # Use the appropriate config class or fall back to base AgentConfig + config_class = agent_config_map.get(agent_type, AgentConfig) + return config_class.model_validate(v) + + return v + + def resolve_seed_parameters(self, action_space: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + Resolve seed parameters by extracting values from the action space. 
+ + Args: + action_space: The flattened action space from cmd_args + + Returns: + Resolved seed parameters with actual values + """ + if not self.agent_config or not hasattr(self.agent_config, 'seed_parameters'): + return None + + seed_params = self.agent_config.seed_parameters + if not seed_params: + return None + + resolved = {} + for param_name, value_spec in seed_params.items(): + if param_name in action_space: + param_options = action_space[param_name] + if isinstance(param_options, list): + if isinstance(value_spec, int) and 0 <= value_spec < len(param_options): + resolved[param_name] = param_options[value_spec] + elif value_spec in param_options: + resolved[param_name] = value_spec + else: + # Default to first option if value not found + resolved[param_name] = param_options[0] + else: + # Single value parameter + resolved[param_name] = param_options + else: + # Parameter not in action space, use as-is (for backwards compatibility) + resolved[param_name] = value_spec + + return resolved @property def cmd_args_dict(self) -> Dict[str, Union[str, List[str]]]: From dbdbcb0c67060da3152810ce6ea46aafcb7f49b9 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 01:01:02 -0700 Subject: [PATCH 02/28] remove random walker agent config --- src/cloudai/models/workload.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/cloudai/models/workload.py b/src/cloudai/models/workload.py index 27eba1e62..f0a7a4247 100644 --- a/src/cloudai/models/workload.py +++ b/src/cloudai/models/workload.py @@ -142,22 +142,16 @@ def parse_agent_config(cls, v, info): if v is None: return None - # If it's already an AgentConfig instance, return as-is if isinstance(v, AgentConfig): return v - # If it's a dict, parse based on agent type if isinstance(v, dict): - # Get the agent type from the model data agent_type = info.data.get('agent', 'grid_search') - # Map agent types to their config classes agent_config_map = { - 'bo_gp': BOAgentConfig, - 'random_walker': RandomWalkerAgentConfig, + 'bo_gp': BOAgentConfig } - # Use the appropriate config class or fall back to base AgentConfig config_class = agent_config_map.get(agent_type, AgentConfig) return config_class.model_validate(v) @@ -190,13 +184,10 @@ def resolve_seed_parameters(self, action_space: Dict[str, Any]) -> Optional[Dict elif value_spec in param_options: resolved[param_name] = value_spec else: - # Default to first option if value not found resolved[param_name] = param_options[0] else: - # Single value parameter resolved[param_name] = param_options else: - # Parameter not in action space, use as-is (for backwards compatibility) resolved[param_name] = value_spec return resolved From 9b80f39772b41d705b7213bd765d1a1f430cf7cf Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 01:10:09 -0700 Subject: [PATCH 03/28] debug logs --- src/cloudai/cli/handlers.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index 9efb4dfb5..6e6f058bf 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -133,20 +133,34 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace): agent_config = test_run.test.test_definition.agent_config + logging.info(f"Handler: agent_config type = {type(agent_config)}") + logging.info(f"Handler: agent_config = {agent_config}") + if agent_config: agent_kwargs = { k: v for k, v in agent_config.model_dump().items() if v is not None and k not in ['extra_params', 'seed_parameters'] } + 
logging.info(f"Handler: checking seed_parameters - hasattr: {hasattr(agent_config, 'seed_parameters')}") + if hasattr(agent_config, 'seed_parameters'): + logging.info(f"Handler: agent_config.seed_parameters = {agent_config.seed_parameters}") + if hasattr(agent_config, 'seed_parameters') and agent_config.seed_parameters: action_space = env.define_action_space() + logging.info(f"Handler: action_space = {action_space}") + logging.info(f"Handler: raw seed_parameters from config = {agent_config.seed_parameters}") resolved_seeds = test_run.test.test_definition.resolve_seed_parameters(action_space) + logging.info(f"Handler: resolved seed_parameters = {resolved_seeds}") if resolved_seeds: agent_kwargs['seed_parameters'] = resolved_seeds + else: + logging.info(f"Handler: No seed_parameters found or seed_parameters is None") agent_kwargs.update(agent_config.extra_params) + logging.info(f"Handler: final agent_kwargs = {agent_kwargs}") + try: agent = agent_class(env, **agent_kwargs) except TypeError as e: From b3f5c9c9aec7cf3af7529b05ee8b8df211e8c73a Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 01:14:38 -0700 Subject: [PATCH 04/28] debug log --- src/cloudai/models/workload.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/cloudai/models/workload.py b/src/cloudai/models/workload.py index f0a7a4247..363ba46e4 100644 --- a/src/cloudai/models/workload.py +++ b/src/cloudai/models/workload.py @@ -139,22 +139,34 @@ class TestDefinition(BaseModel, ABC): @classmethod def parse_agent_config(cls, v, info): """Parse agent_config based on the agent type.""" + import logging + logging.info(f"Field validator called with v = {v}, type = {type(v)}") + if v is None: + logging.info("Field validator: v is None, returning None") return None if isinstance(v, AgentConfig): + logging.info("Field validator: v is already AgentConfig instance") return v if isinstance(v, dict): agent_type = info.data.get('agent', 'grid_search') + logging.info(f"Field validator: agent_type = {agent_type}") + logging.info(f"Field validator: input dict = {v}") agent_config_map = { 'bo_gp': BOAgentConfig } config_class = agent_config_map.get(agent_type, AgentConfig) - return config_class.model_validate(v) + logging.info(f"Field validator: using config_class = {config_class}") + + result = config_class.model_validate(v) + logging.info(f"Field validator: result = {result}") + return result + logging.info(f"Field validator: unexpected type {type(v)}, returning as-is") return v def resolve_seed_parameters(self, action_space: Dict[str, Any]) -> Optional[Dict[str, Any]]: From 44176d1744f501af317b03619eb289c6515a871b Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 01:17:55 -0700 Subject: [PATCH 05/28] debug logs --- src/cloudai/models/scenario.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/cloudai/models/scenario.py b/src/cloudai/models/scenario.py index d7149e97c..9065de776 100644 --- a/src/cloudai/models/scenario.py +++ b/src/cloudai/models/scenario.py @@ -95,6 +95,15 @@ class TestRunModel(BaseModel): def tdef_model_dump(self) -> dict: """Return a dictionary with non-None values that correspond to the test definition fields.""" + import logging + + agent_config_dump = None + if self.agent_config: + agent_config_dump = self.agent_config.model_dump() + logging.info(f"tdef_model_dump: agent_config type = {type(self.agent_config)}") + logging.info(f"tdef_model_dump: agent_config = {self.agent_config}") + 
logging.info(f"tdef_model_dump: agent_config.model_dump() = {agent_config_dump}") + data = { "name": self.name, "description": self.description, @@ -102,7 +111,7 @@ def tdef_model_dump(self) -> dict: "agent": self.agent, "agent_steps": self.agent_steps, "agent_metrics": self.agent_metrics, - "agent_config": self.agent_config.model_dump() if self.agent_config else None, + "agent_config": agent_config_dump, "extra_container_mounts": self.extra_container_mounts, "extra_env_vars": self.extra_env_vars if self.extra_env_vars else None, "cmd_args": self.cmd_args.model_dump() if self.cmd_args else None, From 8053f3c0d3afe623e777fec89b6e0c18198cbc31 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 01:22:51 -0700 Subject: [PATCH 06/28] more debug logs --- src/cloudai/models/workload.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/cloudai/models/workload.py b/src/cloudai/models/workload.py index 363ba46e4..35e49770c 100644 --- a/src/cloudai/models/workload.py +++ b/src/cloudai/models/workload.py @@ -140,6 +140,15 @@ class TestDefinition(BaseModel, ABC): def parse_agent_config(cls, v, info): """Parse agent_config based on the agent type.""" import logging + import traceback + + # Add stack trace for the problematic call + if isinstance(v, dict) and v == {'random_seed': 42, 'extra_params': {}}: + logging.info(f"!!! PROBLEMATIC CALL DETECTED !!!") + logging.info(f"Stack trace:") + for line in traceback.format_stack(): + logging.info(line.strip()) + logging.info(f"Field validator called with v = {v}, type = {type(v)}") if v is None: From 90e1fbf63706b627792eff325d7d37e9c6f98b9b Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 01:25:34 -0700 Subject: [PATCH 07/28] more debug logs --- src/cloudai/test_scenario_parser.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/cloudai/test_scenario_parser.py b/src/cloudai/test_scenario_parser.py index ce5f2a057..40077256c 100644 --- a/src/cloudai/test_scenario_parser.py +++ b/src/cloudai/test_scenario_parser.py @@ -234,10 +234,20 @@ def _prepare_tdef(self, test_info: TestRunModel) -> Tuple[Test, TestDefinition]: raise ValueError(f"Test '{test_info.test_name}' is not defined. 
Was tests directory correctly set?") test = self.test_mapping[test_info.test_name] - test_defined = test.test_definition.model_dump() - tc_defined = test_info.tdef_model_dump() - merged_data = deep_merge(test_defined, tc_defined) - test.test_definition = tp.load_test_definition(merged_data, self.strict) + test_defined = test.test_definition.model_dump() + tc_defined = test_info.tdef_model_dump() + + # Debug logging + import logging + logging.info(f"_prepare_tdef: test.test_definition.agent_config type = {type(test.test_definition.agent_config)}") + logging.info(f"_prepare_tdef: test.test_definition.agent_config = {test.test_definition.agent_config}") + logging.info(f"_prepare_tdef: test_defined agent_config = {test_defined.get('agent_config')}") + logging.info(f"_prepare_tdef: tc_defined agent_config = {tc_defined.get('agent_config')}") + + merged_data = deep_merge(test_defined, tc_defined) + logging.info(f"_prepare_tdef: merged_data agent_config = {merged_data.get('agent_config')}") + + test.test_definition = tp.load_test_definition(merged_data, self.strict) elif test_info.test_template_name: # test fully defined in the scenario test = tp._parse_data(test_info.tdef_model_dump(), self.strict) else: From 72385bd0308cb4b3c925cb2bef770142a74d28e3 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 01:50:40 -0700 Subject: [PATCH 08/28] remove debug logs --- src/cloudai/models/workload.py | 27 +++++++++------------------ src/cloudai/test_scenario_parser.py | 29 +++++++++++++++-------------- 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/src/cloudai/models/workload.py b/src/cloudai/models/workload.py index 35e49770c..856c00777 100644 --- a/src/cloudai/models/workload.py +++ b/src/cloudai/models/workload.py @@ -140,42 +140,33 @@ class TestDefinition(BaseModel, ABC): def parse_agent_config(cls, v, info): """Parse agent_config based on the agent type.""" import logging - import traceback - - # Add stack trace for the problematic call - if isinstance(v, dict) and v == {'random_seed': 42, 'extra_params': {}}: - logging.info(f"!!! PROBLEMATIC CALL DETECTED !!!") - logging.info(f"Stack trace:") - for line in traceback.format_stack(): - logging.info(line.strip()) - - logging.info(f"Field validator called with v = {v}, type = {type(v)}") if v is None: - logging.info("Field validator: v is None, returning None") return None if isinstance(v, AgentConfig): - logging.info("Field validator: v is already AgentConfig instance") return v if isinstance(v, dict): agent_type = info.data.get('agent', 'grid_search') - logging.info(f"Field validator: agent_type = {agent_type}") - logging.info(f"Field validator: input dict = {v}") + + # Critical debugging: Track when BO data is incomplete + if agent_type == 'bo_gp': + has_bo_fields = 'sobol_num_trials' in v or 'botorch_num_trials' in v or 'seed_parameters' in v + if not has_bo_fields: + logging.warning(f"🚨 BO agent_config missing BO fields! 
Input: {v}") + else: + logging.info(f"✅ BO agent_config has BO fields: {v}") agent_config_map = { 'bo_gp': BOAgentConfig } config_class = agent_config_map.get(agent_type, AgentConfig) - logging.info(f"Field validator: using config_class = {config_class}") - result = config_class.model_validate(v) - logging.info(f"Field validator: result = {result}") + return result - logging.info(f"Field validator: unexpected type {type(v)}, returning as-is") return v def resolve_seed_parameters(self, action_space: Dict[str, Any]) -> Optional[Dict[str, Any]]: diff --git a/src/cloudai/test_scenario_parser.py b/src/cloudai/test_scenario_parser.py index 40077256c..207c7fb91 100644 --- a/src/cloudai/test_scenario_parser.py +++ b/src/cloudai/test_scenario_parser.py @@ -234,20 +234,21 @@ def _prepare_tdef(self, test_info: TestRunModel) -> Tuple[Test, TestDefinition]: raise ValueError(f"Test '{test_info.test_name}' is not defined. Was tests directory correctly set?") test = self.test_mapping[test_info.test_name] - test_defined = test.test_definition.model_dump() - tc_defined = test_info.tdef_model_dump() - - # Debug logging - import logging - logging.info(f"_prepare_tdef: test.test_definition.agent_config type = {type(test.test_definition.agent_config)}") - logging.info(f"_prepare_tdef: test.test_definition.agent_config = {test.test_definition.agent_config}") - logging.info(f"_prepare_tdef: test_defined agent_config = {test_defined.get('agent_config')}") - logging.info(f"_prepare_tdef: tc_defined agent_config = {tc_defined.get('agent_config')}") - - merged_data = deep_merge(test_defined, tc_defined) - logging.info(f"_prepare_tdef: merged_data agent_config = {merged_data.get('agent_config')}") - - test.test_definition = tp.load_test_definition(merged_data, self.strict) + test_defined = test.test_definition.model_dump(exclude_none=False) + tc_defined = test_info.tdef_model_dump() + + # Focused debugging: Track agent_config through merge + import logging + if test.test_definition.agent_config: + logging.info(f"🔍 MERGE DEBUG: Original agent_config type: {type(test.test_definition.agent_config)}") + logging.info(f"🔍 MERGE DEBUG: test_defined agent_config: {test_defined.get('agent_config')}") + + merged_data = deep_merge(test_defined, tc_defined) + + if 'agent_config' in merged_data: + logging.info(f"🔍 MERGE DEBUG: merged_data agent_config: {merged_data.get('agent_config')}") + + test.test_definition = tp.load_test_definition(merged_data, self.strict) elif test_info.test_template_name: # test fully defined in the scenario test = tp._parse_data(test_info.tdef_model_dump(), self.strict) else: From 6c720041015900650292f0a48df0e1b334b249b3 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 01:58:19 -0700 Subject: [PATCH 09/28] preserve agent_configs --- src/cloudai/models/workload.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/cloudai/models/workload.py b/src/cloudai/models/workload.py index 856c00777..e79a885e0 100644 --- a/src/cloudai/models/workload.py +++ b/src/cloudai/models/workload.py @@ -48,6 +48,12 @@ class BOAgentConfig(AgentConfig): # Allow for additional agent-specific parameters extra_params: Dict[str, Any] = Field(default_factory=dict) + def model_dump(self, **kwargs): + """Override model_dump to ensure all BO fields are preserved.""" + # Force exclude_none=False to preserve all fields + kwargs['exclude_none'] = False + return super().model_dump(**kwargs) + class CmdArgs(BaseModel): """Test command arguments.""" From 37f7e6b4e3b474363692e7e31038a2ddc2bfc544 Mon Sep 17 
00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 14:36:09 -0700 Subject: [PATCH 10/28] more debugging --- src/cloudai/models/scenario.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/cloudai/models/scenario.py b/src/cloudai/models/scenario.py index 9065de776..a70da3867 100644 --- a/src/cloudai/models/scenario.py +++ b/src/cloudai/models/scenario.py @@ -93,6 +93,40 @@ class TestRunModel(BaseModel): agent_metrics: list[str] = Field(default=["default"]) agent_config: Optional[AgentConfig] = None + @field_validator('agent_config', mode='before') + @classmethod + def parse_agent_config(cls, v, info): + """Parse agent_config based on the agent type.""" + import logging + + if v is None: + return None + + if isinstance(v, AgentConfig): + return v + + if isinstance(v, dict): + agent_type = info.data.get('agent', 'grid_search') + + # Critical debugging: Track when BO data is incomplete + if agent_type == 'bo_gp': + has_bo_fields = 'sobol_num_trials' in v or 'botorch_num_trials' in v or 'seed_parameters' in v + if not has_bo_fields: + logging.warning(f"SCENARIO BO agent_config missing BO fields! Input: {v}") + else: + logging.info(f"SCENARIO BO agent_config has BO fields: {v}") + + agent_config_map = { + 'bo_gp': BOAgentConfig + } + + config_class = agent_config_map.get(agent_type, AgentConfig) + result = config_class.model_validate(v) + + return result + + return v + def tdef_model_dump(self) -> dict: """Return a dictionary with non-None values that correspond to the test definition fields.""" import logging From 7a099f9ac769c0dfe7bb3f881ae46f3d026e0eed Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 14:40:59 -0700 Subject: [PATCH 11/28] debugging --- src/cloudai/models/workload.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/cloudai/models/workload.py b/src/cloudai/models/workload.py index e79a885e0..fe7ab523c 100644 --- a/src/cloudai/models/workload.py +++ b/src/cloudai/models/workload.py @@ -48,11 +48,20 @@ class BOAgentConfig(AgentConfig): # Allow for additional agent-specific parameters extra_params: Dict[str, Any] = Field(default_factory=dict) + def __init__(self, **data): + super().__init__(**data) + import logging + logging.info(f"🆕 BOAgentConfig created with data: {data}") + logging.info(f"🆕 Final BOAgentConfig state: sobol={self.sobol_num_trials}, botorch={self.botorch_num_trials}, seeds={self.seed_parameters}") + def model_dump(self, **kwargs): """Override model_dump to ensure all BO fields are preserved.""" # Force exclude_none=False to preserve all fields kwargs['exclude_none'] = False - return super().model_dump(**kwargs) + result = super().model_dump(**kwargs) + import logging + logging.info(f"📤 BOAgentConfig.model_dump() called, result: {result}") + return result class CmdArgs(BaseModel): From eddfec7a114c9f88e623ddef0c49e2dd981efabc Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 16:43:52 -0700 Subject: [PATCH 12/28] debug --- src/cloudai/test_scenario_parser.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/cloudai/test_scenario_parser.py b/src/cloudai/test_scenario_parser.py index 207c7fb91..b50cea51f 100644 --- a/src/cloudai/test_scenario_parser.py +++ b/src/cloudai/test_scenario_parser.py @@ -243,6 +243,20 @@ def _prepare_tdef(self, test_info: TestRunModel) -> Tuple[Test, TestDefinition]: logging.info(f"🔍 MERGE DEBUG: Original agent_config type: {type(test.test_definition.agent_config)}") logging.info(f"🔍 MERGE 
DEBUG: test_defined agent_config: {test_defined.get('agent_config')}") + # CRITICAL FIX: If tc_defined has a more complete agent_config, prioritize it + test_agent_config = test_defined.get('agent_config') + tc_agent_config = tc_defined.get('agent_config') + + if test_agent_config and tc_agent_config: + # Count non-None values in both configs + test_non_none = sum(1 for v in test_agent_config.values() if v is not None) + tc_non_none = sum(1 for v in tc_agent_config.values() if v is not None) + + # If tc_defined has more complete data, use it instead of merging + if tc_non_none > test_non_none: + logging.info(f"PRIORITIZING tc_agent_config (non-None: {tc_non_none}) over test_agent_config (non-None: {test_non_none})") + test_defined['agent_config'] = tc_agent_config + merged_data = deep_merge(test_defined, tc_defined) if 'agent_config' in merged_data: From 9c35c2613ffbc2c3376654d0db7e7e64e7017e78 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 16:55:48 -0700 Subject: [PATCH 13/28] revert changes/debug logs --- src/cloudai/models/scenario.py | 10 +--------- src/cloudai/models/workload.py | 6 ++---- src/cloudai/test_scenario_parser.py | 26 +------------------------- 3 files changed, 4 insertions(+), 38 deletions(-) diff --git a/src/cloudai/models/scenario.py b/src/cloudai/models/scenario.py index a70da3867..a26e55ddb 100644 --- a/src/cloudai/models/scenario.py +++ b/src/cloudai/models/scenario.py @@ -129,15 +129,7 @@ def parse_agent_config(cls, v, info): def tdef_model_dump(self) -> dict: """Return a dictionary with non-None values that correspond to the test definition fields.""" - import logging - - agent_config_dump = None - if self.agent_config: - agent_config_dump = self.agent_config.model_dump() - logging.info(f"tdef_model_dump: agent_config type = {type(self.agent_config)}") - logging.info(f"tdef_model_dump: agent_config = {self.agent_config}") - logging.info(f"tdef_model_dump: agent_config.model_dump() = {agent_config_dump}") - + agent_config_dump = self.agent_config.model_dump() if self.agent_config else None data = { "name": self.name, "description": self.description, diff --git a/src/cloudai/models/workload.py b/src/cloudai/models/workload.py index fe7ab523c..1988f3928 100644 --- a/src/cloudai/models/workload.py +++ b/src/cloudai/models/workload.py @@ -154,7 +154,6 @@ class TestDefinition(BaseModel, ABC): @classmethod def parse_agent_config(cls, v, info): """Parse agent_config based on the agent type.""" - import logging if v is None: return None @@ -165,13 +164,12 @@ def parse_agent_config(cls, v, info): if isinstance(v, dict): agent_type = info.data.get('agent', 'grid_search') - # Critical debugging: Track when BO data is incomplete if agent_type == 'bo_gp': has_bo_fields = 'sobol_num_trials' in v or 'botorch_num_trials' in v or 'seed_parameters' in v if not has_bo_fields: - logging.warning(f"🚨 BO agent_config missing BO fields! 
Input: {v}") + pass else: - logging.info(f"✅ BO agent_config has BO fields: {v}") + pass agent_config_map = { 'bo_gp': BOAgentConfig diff --git a/src/cloudai/test_scenario_parser.py b/src/cloudai/test_scenario_parser.py index b50cea51f..890e8d647 100644 --- a/src/cloudai/test_scenario_parser.py +++ b/src/cloudai/test_scenario_parser.py @@ -237,36 +237,12 @@ def _prepare_tdef(self, test_info: TestRunModel) -> Tuple[Test, TestDefinition]: test_defined = test.test_definition.model_dump(exclude_none=False) tc_defined = test_info.tdef_model_dump() - # Focused debugging: Track agent_config through merge - import logging - if test.test_definition.agent_config: - logging.info(f"🔍 MERGE DEBUG: Original agent_config type: {type(test.test_definition.agent_config)}") - logging.info(f"🔍 MERGE DEBUG: test_defined agent_config: {test_defined.get('agent_config')}") - - # CRITICAL FIX: If tc_defined has a more complete agent_config, prioritize it - test_agent_config = test_defined.get('agent_config') - tc_agent_config = tc_defined.get('agent_config') - - if test_agent_config and tc_agent_config: - # Count non-None values in both configs - test_non_none = sum(1 for v in test_agent_config.values() if v is not None) - tc_non_none = sum(1 for v in tc_agent_config.values() if v is not None) - - # If tc_defined has more complete data, use it instead of merging - if tc_non_none > test_non_none: - logging.info(f"PRIORITIZING tc_agent_config (non-None: {tc_non_none}) over test_agent_config (non-None: {test_non_none})") - test_defined['agent_config'] = tc_agent_config - merged_data = deep_merge(test_defined, tc_defined) - if 'agent_config' in merged_data: - logging.info(f"🔍 MERGE DEBUG: merged_data agent_config: {merged_data.get('agent_config')}") - test.test_definition = tp.load_test_definition(merged_data, self.strict) - elif test_info.test_template_name: # test fully defined in the scenario + elif test_info.test_template_name: test = tp._parse_data(test_info.tdef_model_dump(), self.strict) else: - # this should never happen, because we check for this in the modelvalidator raise ValueError( f"Cannot configure test case '{test_info.id}' with both 'test_name' and 'test_template_name'." ) From dac116df189cc54a49e9fead331d4660fb3c61bc Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 17:39:42 -0700 Subject: [PATCH 14/28] fix for param seeding --- src/cloudai/models/workload.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/cloudai/models/workload.py b/src/cloudai/models/workload.py index 1988f3928..db8fd6f62 100644 --- a/src/cloudai/models/workload.py +++ b/src/cloudai/models/workload.py @@ -204,11 +204,16 @@ def resolve_seed_parameters(self, action_space: Dict[str, Any]) -> Optional[Dict if param_name in action_space: param_options = action_space[param_name] if isinstance(param_options, list): - if isinstance(value_spec, int) and 0 <= value_spec < len(param_options): - resolved[param_name] = param_options[value_spec] - elif value_spec in param_options: + # First, try direct value match – this avoids ambiguity when the desired + # literal value is itself an integer that could also be interpreted as an + # index (e.g., 2 in [1, 2, 4]). Only if the value is not present in the + # list do we fall back to interpreting it as an index. 
+ if value_spec in param_options: resolved[param_name] = value_spec + elif isinstance(value_spec, int) and 0 <= value_spec < len(param_options): + resolved[param_name] = param_options[value_spec] else: + # As a last resort, pick the first option to guarantee a valid seed. resolved[param_name] = param_options[0] else: resolved[param_name] = param_options From 2ae0261402f03bbbd8911894d54248231d6b68aa Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 17:50:38 -0700 Subject: [PATCH 15/28] more fixes --- src/cloudai/models/scenario.py | 28 +++++++++++----------------- src/cloudai/models/workload.py | 25 +++++++++---------------- 2 files changed, 20 insertions(+), 33 deletions(-) diff --git a/src/cloudai/models/scenario.py b/src/cloudai/models/scenario.py index a26e55ddb..4f08174f9 100644 --- a/src/cloudai/models/scenario.py +++ b/src/cloudai/models/scenario.py @@ -106,24 +106,18 @@ def parse_agent_config(cls, v, info): return v if isinstance(v, dict): - agent_type = info.data.get('agent', 'grid_search') + # Check for BO-specific fields directly instead of relying on agent field + # since field validation order means agent might not be available yet + has_bo_fields = {'sobol_num_trials', 'botorch_num_trials', 'seed_parameters'} & v.keys() - # Critical debugging: Track when BO data is incomplete - if agent_type == 'bo_gp': - has_bo_fields = 'sobol_num_trials' in v or 'botorch_num_trials' in v or 'seed_parameters' in v - if not has_bo_fields: - logging.warning(f"SCENARIO BO agent_config missing BO fields! Input: {v}") - else: - logging.info(f"SCENARIO BO agent_config has BO fields: {v}") - - agent_config_map = { - 'bo_gp': BOAgentConfig - } - - config_class = agent_config_map.get(agent_type, AgentConfig) - result = config_class.model_validate(v) - - return result + if has_bo_fields: + logging.info(f"SCENARIO BO agent_config has BO fields: {v}") + # Use BOAgentConfig when BO-specific fields are present + return BOAgentConfig.model_validate(v) + else: + logging.warning(f"SCENARIO agent_config missing BO fields! 
Input: {v}") + # Fall back to base AgentConfig for other cases + return AgentConfig.model_validate(v) return v diff --git a/src/cloudai/models/workload.py b/src/cloudai/models/workload.py index db8fd6f62..7f4708dd9 100644 --- a/src/cloudai/models/workload.py +++ b/src/cloudai/models/workload.py @@ -162,23 +162,16 @@ def parse_agent_config(cls, v, info): return v if isinstance(v, dict): - agent_type = info.data.get('agent', 'grid_search') + # Check for BO-specific fields directly instead of relying on agent field + # since field validation order means agent might not be available yet + has_bo_fields = {'sobol_num_trials', 'botorch_num_trials', 'seed_parameters'} & v.keys() - if agent_type == 'bo_gp': - has_bo_fields = 'sobol_num_trials' in v or 'botorch_num_trials' in v or 'seed_parameters' in v - if not has_bo_fields: - pass - else: - pass - - agent_config_map = { - 'bo_gp': BOAgentConfig - } - - config_class = agent_config_map.get(agent_type, AgentConfig) - result = config_class.model_validate(v) - - return result + if has_bo_fields: + # Use BOAgentConfig when BO-specific fields are present + return BOAgentConfig.model_validate(v) + else: + # Fall back to base AgentConfig for other cases + return AgentConfig.model_validate(v) return v From 0af7504727c45fe06e3c1557b90450b09078d63d Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 17:55:49 -0700 Subject: [PATCH 16/28] add bo agent specific fields --- src/cloudai/models/scenario.py | 7 +++++-- src/cloudai/models/workload.py | 12 ++++++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/cloudai/models/scenario.py b/src/cloudai/models/scenario.py index 4f08174f9..ed69930b9 100644 --- a/src/cloudai/models/scenario.py +++ b/src/cloudai/models/scenario.py @@ -110,9 +110,12 @@ def parse_agent_config(cls, v, info): # since field validation order means agent might not be available yet has_bo_fields = {'sobol_num_trials', 'botorch_num_trials', 'seed_parameters'} & v.keys() - if has_bo_fields: + # Also check for agent_type discriminator field + is_bo_agent = v.get('agent_type') == 'bo_gp' + + if has_bo_fields or is_bo_agent: logging.info(f"SCENARIO BO agent_config has BO fields: {v}") - # Use BOAgentConfig when BO-specific fields are present + # Use BOAgentConfig when BO-specific fields are present or agent_type indicates BO return BOAgentConfig.model_validate(v) else: logging.warning(f"SCENARIO agent_config missing BO fields! 
Input: {v}") diff --git a/src/cloudai/models/workload.py b/src/cloudai/models/workload.py index 7f4708dd9..8e68c8377 100644 --- a/src/cloudai/models/workload.py +++ b/src/cloudai/models/workload.py @@ -38,6 +38,9 @@ class AgentConfig(BaseModel): class BOAgentConfig(AgentConfig): """Configuration for Bayesian Optimization agent.""" + # Add discriminator field to identify this as a BO config + agent_type: str = "bo_gp" + # BO-specific parameters sobol_num_trials: Optional[int] = None botorch_num_trials: Optional[int] = None @@ -59,6 +62,8 @@ def model_dump(self, **kwargs): # Force exclude_none=False to preserve all fields kwargs['exclude_none'] = False result = super().model_dump(**kwargs) + # Ensure agent_type is always included to identify this as BO config + result['agent_type'] = self.agent_type import logging logging.info(f"📤 BOAgentConfig.model_dump() called, result: {result}") return result @@ -166,8 +171,11 @@ def parse_agent_config(cls, v, info): # since field validation order means agent might not be available yet has_bo_fields = {'sobol_num_trials', 'botorch_num_trials', 'seed_parameters'} & v.keys() - if has_bo_fields: - # Use BOAgentConfig when BO-specific fields are present + # Also check for agent_type discriminator field + is_bo_agent = v.get('agent_type') == 'bo_gp' + + if has_bo_fields or is_bo_agent: + # Use BOAgentConfig when BO-specific fields are present or agent_type indicates BO return BOAgentConfig.model_validate(v) else: # Fall back to base AgentConfig for other cases From b6cb985cc9a27e7265ec6ab444414953ecaedd58 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 18:04:56 -0700 Subject: [PATCH 17/28] debug messages --- src/cloudai/models/workload.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/cloudai/models/workload.py b/src/cloudai/models/workload.py index 8e68c8377..3c28ccb9c 100644 --- a/src/cloudai/models/workload.py +++ b/src/cloudai/models/workload.py @@ -52,10 +52,14 @@ class BOAgentConfig(AgentConfig): extra_params: Dict[str, Any] = Field(default_factory=dict) def __init__(self, **data): + # Ensure agent_type is always set even if not in input data + if 'agent_type' not in data: + data['agent_type'] = 'bo_gp' super().__init__(**data) import logging logging.info(f"🆕 BOAgentConfig created with data: {data}") logging.info(f"🆕 Final BOAgentConfig state: sobol={self.sobol_num_trials}, botorch={self.botorch_num_trials}, seeds={self.seed_parameters}") + logging.info(f"🆕 BOAgentConfig agent_type: {self.agent_type}") def model_dump(self, **kwargs): """Override model_dump to ensure all BO fields are preserved.""" From 2efd3bbe652a3cb7e9506dc2a40bc256795be952 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 18:08:40 -0700 Subject: [PATCH 18/28] more debug logs --- src/cloudai/test_scenario_parser.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/cloudai/test_scenario_parser.py b/src/cloudai/test_scenario_parser.py index 890e8d647..7362cc242 100644 --- a/src/cloudai/test_scenario_parser.py +++ b/src/cloudai/test_scenario_parser.py @@ -237,8 +237,14 @@ def _prepare_tdef(self, test_info: TestRunModel) -> Tuple[Test, TestDefinition]: test_defined = test.test_definition.model_dump(exclude_none=False) tc_defined = test_info.tdef_model_dump() + import logging + logging.info(f"🔄 MERGE test_defined agent_config: {test_defined.get('agent_config', 'NOT_PRESENT')}") + logging.info(f"🔄 MERGE tc_defined agent_config: {tc_defined.get('agent_config', 'NOT_PRESENT')}") + merged_data = deep_merge(test_defined, 
tc_defined) + logging.info(f"🔄 MERGE merged agent_config: {merged_data.get('agent_config', 'NOT_PRESENT')}") + test.test_definition = tp.load_test_definition(merged_data, self.strict) elif test_info.test_template_name: test = tp._parse_data(test_info.tdef_model_dump(), self.strict) From d39784ff6918ba12b3f698ebd8f1d6dce7bd7816 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 22:43:24 -0700 Subject: [PATCH 19/28] clean up and remove debug logging --- src/cloudai/cli/handlers.py | 16 +--------------- src/cloudai/models/scenario.py | 12 ++---------- src/cloudai/models/workload.py | 13 +------------ src/cloudai/test_scenario_parser.py | 6 ------ 4 files changed, 4 insertions(+), 43 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index 6e6f058bf..af4f4f655 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -133,34 +133,20 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace): agent_config = test_run.test.test_definition.agent_config - logging.info(f"Handler: agent_config type = {type(agent_config)}") - logging.info(f"Handler: agent_config = {agent_config}") - if agent_config: agent_kwargs = { k: v for k, v in agent_config.model_dump().items() - if v is not None and k not in ['extra_params', 'seed_parameters'] + if v is not None and k not in ['extra_params', 'seed_parameters', 'agent_type'] } - logging.info(f"Handler: checking seed_parameters - hasattr: {hasattr(agent_config, 'seed_parameters')}") - if hasattr(agent_config, 'seed_parameters'): - logging.info(f"Handler: agent_config.seed_parameters = {agent_config.seed_parameters}") - if hasattr(agent_config, 'seed_parameters') and agent_config.seed_parameters: action_space = env.define_action_space() - logging.info(f"Handler: action_space = {action_space}") - logging.info(f"Handler: raw seed_parameters from config = {agent_config.seed_parameters}") resolved_seeds = test_run.test.test_definition.resolve_seed_parameters(action_space) - logging.info(f"Handler: resolved seed_parameters = {resolved_seeds}") if resolved_seeds: agent_kwargs['seed_parameters'] = resolved_seeds - else: - logging.info(f"Handler: No seed_parameters found or seed_parameters is None") agent_kwargs.update(agent_config.extra_params) - logging.info(f"Handler: final agent_kwargs = {agent_kwargs}") - try: agent = agent_class(env, **agent_kwargs) except TypeError as e: diff --git a/src/cloudai/models/scenario.py b/src/cloudai/models/scenario.py index ed69930b9..b93ad1621 100644 --- a/src/cloudai/models/scenario.py +++ b/src/cloudai/models/scenario.py @@ -17,7 +17,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any, Literal, Optional +from typing import Any, Literal, Optional, Union from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_serializer, field_validator, model_validator @@ -91,13 +91,12 @@ class TestRunModel(BaseModel): agent: Optional[str] = None agent_steps: Optional[int] = None agent_metrics: list[str] = Field(default=["default"]) - agent_config: Optional[AgentConfig] = None + agent_config: Optional[Union[AgentConfig, BOAgentConfig]] = None @field_validator('agent_config', mode='before') @classmethod def parse_agent_config(cls, v, info): """Parse agent_config based on the agent type.""" - import logging if v is None: return None @@ -106,20 +105,13 @@ def parse_agent_config(cls, v, info): return v if isinstance(v, dict): - # Check for BO-specific fields directly instead of relying on agent field - # since field 
validation order means agent might not be available yet has_bo_fields = {'sobol_num_trials', 'botorch_num_trials', 'seed_parameters'} & v.keys() - # Also check for agent_type discriminator field is_bo_agent = v.get('agent_type') == 'bo_gp' if has_bo_fields or is_bo_agent: - logging.info(f"SCENARIO BO agent_config has BO fields: {v}") - # Use BOAgentConfig when BO-specific fields are present or agent_type indicates BO return BOAgentConfig.model_validate(v) else: - logging.warning(f"SCENARIO agent_config missing BO fields! Input: {v}") - # Fall back to base AgentConfig for other cases return AgentConfig.model_validate(v) return v diff --git a/src/cloudai/models/workload.py b/src/cloudai/models/workload.py index 3c28ccb9c..27f6cfbe4 100644 --- a/src/cloudai/models/workload.py +++ b/src/cloudai/models/workload.py @@ -56,10 +56,6 @@ def __init__(self, **data): if 'agent_type' not in data: data['agent_type'] = 'bo_gp' super().__init__(**data) - import logging - logging.info(f"🆕 BOAgentConfig created with data: {data}") - logging.info(f"🆕 Final BOAgentConfig state: sobol={self.sobol_num_trials}, botorch={self.botorch_num_trials}, seeds={self.seed_parameters}") - logging.info(f"🆕 BOAgentConfig agent_type: {self.agent_type}") def model_dump(self, **kwargs): """Override model_dump to ensure all BO fields are preserved.""" @@ -68,8 +64,6 @@ def model_dump(self, **kwargs): result = super().model_dump(**kwargs) # Ensure agent_type is always included to identify this as BO config result['agent_type'] = self.agent_type - import logging - logging.info(f"📤 BOAgentConfig.model_dump() called, result: {result}") return result @@ -157,7 +151,7 @@ class TestDefinition(BaseModel, ABC): agent_steps: int = 1 agent_metrics: list[str] = Field(default=["default"]) agent_reward_function: str = "inverse" - agent_config: Optional[AgentConfig] = None + agent_config: Optional[Union[AgentConfig, BOAgentConfig]] = None @field_validator('agent_config', mode='before') @classmethod @@ -209,16 +203,11 @@ def resolve_seed_parameters(self, action_space: Dict[str, Any]) -> Optional[Dict if param_name in action_space: param_options = action_space[param_name] if isinstance(param_options, list): - # First, try direct value match – this avoids ambiguity when the desired - # literal value is itself an integer that could also be interpreted as an - # index (e.g., 2 in [1, 2, 4]). Only if the value is not present in the - # list do we fall back to interpreting it as an index. if value_spec in param_options: resolved[param_name] = value_spec elif isinstance(value_spec, int) and 0 <= value_spec < len(param_options): resolved[param_name] = param_options[value_spec] else: - # As a last resort, pick the first option to guarantee a valid seed. 
resolved[param_name] = param_options[0] else: resolved[param_name] = param_options diff --git a/src/cloudai/test_scenario_parser.py b/src/cloudai/test_scenario_parser.py index 7362cc242..890e8d647 100644 --- a/src/cloudai/test_scenario_parser.py +++ b/src/cloudai/test_scenario_parser.py @@ -237,14 +237,8 @@ def _prepare_tdef(self, test_info: TestRunModel) -> Tuple[Test, TestDefinition]: test_defined = test.test_definition.model_dump(exclude_none=False) tc_defined = test_info.tdef_model_dump() - import logging - logging.info(f"🔄 MERGE test_defined agent_config: {test_defined.get('agent_config', 'NOT_PRESENT')}") - logging.info(f"🔄 MERGE tc_defined agent_config: {tc_defined.get('agent_config', 'NOT_PRESENT')}") - merged_data = deep_merge(test_defined, tc_defined) - logging.info(f"🔄 MERGE merged agent_config: {merged_data.get('agent_config', 'NOT_PRESENT')}") - test.test_definition = tp.load_test_definition(merged_data, self.strict) elif test_info.test_template_name: test = tp._parse_data(test_info.tdef_model_dump(), self.strict) From f59e20c892515d46722499cedc0aa7a200ca2c90 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Mon, 11 Aug 2025 17:14:32 -0700 Subject: [PATCH 20/28] Calculate the num devices dynamically (GB200 case) --- src/cloudai/workloads/nemo_run/nemo_run.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/cloudai/workloads/nemo_run/nemo_run.py b/src/cloudai/workloads/nemo_run/nemo_run.py index b0094ea9b..97f1950dc 100644 --- a/src/cloudai/workloads/nemo_run/nemo_run.py +++ b/src/cloudai/workloads/nemo_run/nemo_run.py @@ -82,6 +82,7 @@ class Trainer(BaseModel): max_steps: Union[int, List[int]] = 100 val_check_interval: Union[int, float, list[Union[int, float]]] = 1000 num_nodes: Optional[int] = None # sweeps are done via TestRun.num_nodes + num_devices: Optional[int] = None strategy: TrainerStrategy = Field(default_factory=TrainerStrategy) plugins: Optional[Plugin] = None callbacks: Optional[Union[str, list[str]]] = None @@ -150,7 +151,7 @@ def constraint_check(self, tr: TestRun) -> bool: pp = cast(int, self.cmd_args.trainer.strategy.pipeline_model_parallel_size) cp = cast(int, self.cmd_args.trainer.strategy.context_parallel_size) vp = cast(Optional[int], self.cmd_args.trainer.strategy.virtual_pipeline_model_parallel_size) - num_gpus = tr.nnodes * 8 + num_gpus = tr.nnodes * (self.cmd_args.trainer.num_devices if self.cmd_args.trainer.num_devices else 8) num_layers = cast(int, self.cmd_args.num_layers) dp = num_gpus // (tp * pp * cp) mbs = cast(int, self.cmd_args.data.micro_batch_size) From 620fa9f5dbb2a9b5a0f3f5dff12fd7dfa2160361 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Mon, 11 Aug 2025 17:22:04 -0700 Subject: [PATCH 21/28] typo (trainer.num_devices) --> trainer.devices --- src/cloudai/workloads/nemo_run/nemo_run.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cloudai/workloads/nemo_run/nemo_run.py b/src/cloudai/workloads/nemo_run/nemo_run.py index 97f1950dc..8be9a7b38 100644 --- a/src/cloudai/workloads/nemo_run/nemo_run.py +++ b/src/cloudai/workloads/nemo_run/nemo_run.py @@ -82,7 +82,7 @@ class Trainer(BaseModel): max_steps: Union[int, List[int]] = 100 val_check_interval: Union[int, float, list[Union[int, float]]] = 1000 num_nodes: Optional[int] = None # sweeps are done via TestRun.num_nodes - num_devices: Optional[int] = None + devices: Optional[int] = None strategy: TrainerStrategy = Field(default_factory=TrainerStrategy) plugins: Optional[Plugin] = None callbacks: Optional[Union[str, list[str]]] = None 
@@ -151,7 +151,7 @@ def constraint_check(self, tr: TestRun) -> bool: pp = cast(int, self.cmd_args.trainer.strategy.pipeline_model_parallel_size) cp = cast(int, self.cmd_args.trainer.strategy.context_parallel_size) vp = cast(Optional[int], self.cmd_args.trainer.strategy.virtual_pipeline_model_parallel_size) - num_gpus = tr.nnodes * (self.cmd_args.trainer.num_devices if self.cmd_args.trainer.num_devices else 8) + num_gpus = tr.nnodes * (self.cmd_args.trainer.devices if self.cmd_args.trainer.devices else 8) num_layers = cast(int, self.cmd_args.num_layers) dp = num_gpus // (tp * pp * cp) mbs = cast(int, self.cmd_args.data.micro_batch_size) From 2ae0261402f03bbbd8911894d54248231d6b68aa Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 7 Aug 2025 22:43:24 -0700 Subject: [PATCH 22/28] separate reward value for failed constraint check --- src/cloudai/configurator/cloudai_gym.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cloudai/configurator/cloudai_gym.py b/src/cloudai/configurator/cloudai_gym.py index d27b19e1e..92a5125b9 100644 --- a/src/cloudai/configurator/cloudai_gym.py +++ b/src/cloudai/configurator/cloudai_gym.py @@ -102,7 +102,7 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]: if not self.test_run.test.test_definition.constraint_check(self.test_run): logging.info("Constraint check failed. Skipping step.") - return [-1.0], -1.0, True, {} + return [-1.0], -1e-6, True, {} logging.info(f"Running step {self.test_run.step} with action {action}") new_tr = copy.deepcopy(self.test_run) From 262d1a1fd6cdf7eeaf3bf2824a3bbece20e3ed15 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Wed, 13 Aug 2025 18:24:42 -0700 Subject: [PATCH 23/28] 405b refactor --- .../workloads/nemo_run/cloudai_nemorun.py | 143 ++++++------------ 1 file changed, 44 insertions(+), 99 deletions(-) diff --git a/src/cloudai/workloads/nemo_run/cloudai_nemorun.py b/src/cloudai/workloads/nemo_run/cloudai_nemorun.py index 9e725ae20..6b89eff9c 100644 --- a/src/cloudai/workloads/nemo_run/cloudai_nemorun.py +++ b/src/cloudai/workloads/nemo_run/cloudai_nemorun.py @@ -20,6 +20,8 @@ import lightning.pytorch as pl import nemo_run as run import torch +import fiddle as fdl +import fiddle.dataclasses as fdl_dc from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.loggers.wandb import WandbLogger from megatron.core.distributed import DistributedDataParallelConfig @@ -32,6 +34,7 @@ from nemo.collections.llm.gpt.model.llama import Llama3Config8B, Llama3Config70B, Llama31Config405B, LlamaModel from nemo.collections.llm.gpt.model.nemotron import Nemotron4Config15B, Nemotron4Config340B, NemotronModel from nemo.collections.llm.recipes.nemotron3_8b import pretrain_recipe as nemotron3_8b_recipe +from nemo.collections.llm.recipes.llama31_405 import pretrain_recipe as llama31_405_pretrain_recipe from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( BulkOverlapCfg, PipelineOverlapCfg, RingExchangeOverlapCfg, TransformerLayerTPOverlapCfg, ) from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer from nemo.lightning import AutoResume, NeMoLogger @@ -618,6 +621,25 @@ def get_tp_overlap_config(): return tp_overlap_cfg, tp_comm_overlap +# Find the index of MegatronCommOverlapCallback in an existing callbacks list +def get_comm_overlap_callback_idx(callbacks: list) -> Optional[int]: + for idx, cb in enumerate(callbacks): + fn_or_cls = getattr(cb, "__fn_or_cls__", None) + if fn_or_cls is MegatronCommOverlapCallback: + return idx + target = getattr(cb, "target", None) or getattr(cb, "cls", None) + if target is MegatronCommOverlapCallback: + return idx + return None + + +# Convert dataclass-based TP overlap cfgs to 
run.Config when needed +def to_run_config(obj): + if hasattr(obj, "__dataclass_fields__"): + return fdl.cast(run.Config, fdl_dc.convert_dataclasses_to_configs(obj)) + return obj + + # LLAMA3 8B Recipe @run.cli.factory(target=llm.pretrain) def cloudai_llama3_8b_recipe() -> run.Partial: @@ -833,104 +855,27 @@ def cloudai_llama3_70b_recipe() -> run.Partial: # LLAMA3 405B Recipe @run.cli.factory(target=llm.pretrain) def cloudai_llama3_405b_recipe() -> run.Partial: - recipe = run.Partial( - llm.pretrain, - model=run.Config(LlamaModel, config=Llama31Config405B()), - data=run.Config( - MockDataModule, - seq_length=8192, - micro_batch_size=1, - global_batch_size=8, - tokenizer=null_tokenizer(vocab_size=128256), - ), - trainer=run.Config( - nl.Trainer, - devices=8, - num_nodes=1, - accelerator="gpu", - max_steps=10, - limit_test_batches=50, - limit_val_batches=32, - log_every_n_steps=10, - accumulate_grad_batches=1, - plugins=run.Config( - nl.MegatronMixedPrecision, - precision="bf16-mixed", - params_dtype=torch.bfloat16, - pipeline_dtype=torch.bfloat16, - autocast_enabled=False, - grad_reduce_in_fp32=False, - ), - strategy=run.Config( - nl.MegatronStrategy, - tensor_model_parallel_size=8, - pipeline_model_parallel_size=1, - context_parallel_size=2, - virtual_pipeline_model_parallel_size=8, - sequence_parallel=True, - expert_model_parallel_size=1, - expert_tensor_parallel_size=None, - pipeline_dtype=torch.bfloat16, - gradient_as_bucket_view=True, - ckpt_async_save=True, - ckpt_parallel_load=True, - ddp=run.Config( - DistributedDataParallelConfig, - check_for_nan_in_grad=False, - grad_reduce_in_fp32=True, - overlap_grad_reduce=True, - overlap_param_gather=True, - ), - ), - num_sanity_val_steps=0, - use_distributed_sampler=False, - val_check_interval=1000, - max_epochs=10, - callbacks=[ - timing_callback(), - ], - ), - optim=run.Config( - nl.MegatronOptimizerModule, - config=run.Config( - OptimizerConfig, - lr=0.0003, - bf16=True, - use_precision_aware_optimizer=True, - params_dtype=torch.bfloat16, - use_distributed_optimizer=True, - weight_decay=0.1, - adam_beta1=0.9, - adam_beta2=0.95, - adam_eps=1e-05, - clip_grad=1.0, - ), - lr_scheduler=run.Config( - CosineAnnealingScheduler, - warmup_steps=2000, - constant_steps=0, - min_lr=2.9999999999999997e-05, - ), - ), - resume=run.Config( - nl.AutoResume, - resume_if_exists=True, - resume_ignore_no_checkpoint=True, - resume_past_end=True, - ), - ) + recipe = llama31_405_pretrain_recipe(performance_mode=True) + + # Optional tokenizer override for CloudAI runs + recipe.data.tokenizer = null_tokenizer(vocab_size=128256) + recipe.model.config.expert_tensor_parallel_size = None recipe.model.config.seq_length = 8192 tp_overlap_cfg, tp_comm_overlap = get_tp_overlap_config() - megatron_comm_overlap_callback = run.Config( - MegatronCommOverlapCallback, - tp_comm_overlap=tp_comm_overlap, - tp_comm_overlap_cfg=tp_overlap_cfg, - overlap_param_gather_with_optimizer_step=True, - defer_embedding_wgrad_compute=True, - wgrad_deferral_limit=50, - ) + + # Locate and update the existing comm-overlap callback to match our env + comm_overlap_callback_idx = get_comm_overlap_callback_idx(recipe.trainer.callbacks) + assert comm_overlap_callback_idx is not None, "MegatronCommOverlapCallback missing. Required for performance." 
+ + tp_comm_overlap_cfg = to_run_config(tp_overlap_cfg) + cb = recipe.trainer.callbacks[comm_overlap_callback_idx] + cb.tp_comm_overlap = tp_comm_overlap + cb.tp_comm_overlap_cfg = tp_comm_overlap_cfg + cb.overlap_param_gather_with_optimizer_step = True + cb.defer_embedding_wgrad_compute = True + cb.wgrad_deferral_limit = 50 enable_fsdp = os.getenv("CLOUDAI_ENABLE_FSDP", "0") == "1" disable_tp_commd_overlap = os.getenv("CLOUDAI_DISABLE_TP_COMM_OVERLAP", "0") == "1" @@ -942,12 +887,13 @@ def cloudai_llama3_405b_recipe() -> run.Partial: recipe.trainer.strategy.ddp.average_in_collective = False recipe.trainer.strategy.ddp.keep_fp8_transpose_cache_when_using_custom_fsdp = False recipe.model.config.gradient_accumulation_fusion = False - megatron_comm_overlap_callback.defer_embedding_wgrad_compute = False - megatron_comm_overlap_callback.wgrad_deferral_limit = 50 - megatron_comm_overlap_callback.overlap_param_gather_with_optimizer_step = False + + cb.defer_embedding_wgrad_compute = False + cb.wgrad_deferral_limit = 50 + cb.overlap_param_gather_with_optimizer_step = False if disable_tp_commd_overlap: - megatron_comm_overlap_callback.tp_comm_overlap = False + cb.tp_comm_overlap = False recompute_layers = int(os.getenv("CLOUDAI_RECOMPUTE_LAYERS", "0")) if recompute_layers > 0: @@ -969,7 +915,6 @@ def cloudai_llama3_405b_recipe() -> run.Partial: model_name="llama3", ) ) - recipe.trainer.callbacks.append(megatron_comm_overlap_callback) recipe.trainer.callbacks.append(run.Config(GarbageCollectionCallback, gc_interval_train=100, gc_interval_val=100)) recipe.trainer.strategy.account_for_embedding_in_pipeline_split = True recipe.trainer.strategy.account_for_loss_in_pipeline_split = True From 46c6ecc168efa1905f6a79c05c1aba9f8c2bc02d Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 14 Aug 2025 14:15:31 -0700 Subject: [PATCH 24/28] remove the fdl --- src/cloudai/workloads/nemo_run/cloudai_nemorun.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/cloudai/workloads/nemo_run/cloudai_nemorun.py b/src/cloudai/workloads/nemo_run/cloudai_nemorun.py index 6b89eff9c..13fd1db53 100644 --- a/src/cloudai/workloads/nemo_run/cloudai_nemorun.py +++ b/src/cloudai/workloads/nemo_run/cloudai_nemorun.py @@ -20,8 +20,6 @@ import lightning.pytorch as pl import nemo_run as run import torch -import fiddle as fdl -import fiddle.dataclasses as fdl_dc from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.loggers.wandb import WandbLogger from megatron.core.distributed import DistributedDataParallelConfig @@ -635,8 +633,6 @@ def get_comm_overlap_callback_idx(callbacks: list) -> Optional[int]: # Convert dataclass-based TP overlap cfgs to run.Config when needed def to_run_config(obj): - if hasattr(obj, "__dataclass_fields__"): - return fdl.cast(run.Config, fdl_dc.convert_dataclasses_to_configs(obj)) return obj From 1700f6bed53560dbc071a91d6465b3dcfdeb74d4 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Thu, 14 Aug 2025 17:25:11 -0700 Subject: [PATCH 25/28] remove tp_comm_overp disable with fsdp --- src/cloudai/workloads/nemo_run/cloudai_nemorun.py | 2 +- .../workloads/nemo_run/slurm_command_gen_strategy.py | 9 --------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/src/cloudai/workloads/nemo_run/cloudai_nemorun.py b/src/cloudai/workloads/nemo_run/cloudai_nemorun.py index 13fd1db53..915953f8f 100644 --- a/src/cloudai/workloads/nemo_run/cloudai_nemorun.py +++ b/src/cloudai/workloads/nemo_run/cloudai_nemorun.py @@ -32,7 +32,7 @@ from nemo.collections.llm.gpt.model.llama 
import Llama3Config8B, Llama3Config70B, Llama31Config405B, LlamaModel from nemo.collections.llm.gpt.model.nemotron import Nemotron4Config15B, Nemotron4Config340B, NemotronModel from nemo.collections.llm.recipes.nemotron3_8b import pretrain_recipe as nemotron3_8b_recipe -from nemo.collections.llm.recipes.llama31_405 import pretrain_recipe as llama31_405_pretrain_recipe +from nemo.collections.llm.recipes.llama31_405b import pretrain_recipe as llama31_405_pretrain_recipe from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( BulkOverlapCfg, PipelineOverlapCfg, diff --git a/src/cloudai/workloads/nemo_run/slurm_command_gen_strategy.py b/src/cloudai/workloads/nemo_run/slurm_command_gen_strategy.py index f03fc25d7..748f8f5bd 100644 --- a/src/cloudai/workloads/nemo_run/slurm_command_gen_strategy.py +++ b/src/cloudai/workloads/nemo_run/slurm_command_gen_strategy.py @@ -49,15 +49,6 @@ def _set_additional_env_vars(self, tdef: NeMoRunTestDefinition): logging.debug("Setting NCCL_P2P_NET_CHUNKSIZE to 2097152 as pipeline_model_parallel_size is greater than 1") self.final_env_vars["NCCL_P2P_NET_CHUNKSIZE"] = "2097152" - enable_fsdp = self.final_env_vars.get("CLOUDAI_ENABLE_FSDP", "0") - if enable_fsdp == "1": - logging.info( - ( - "CLOUDAI_ENABLE_FSDP is set to 1. Currently, NemoRun does not support FSDP " - "with TP communication overlap." - ) - ) - self.final_env_vars["CLOUDAI_DISABLE_TP_COMM_OVERLAP"] = "1" def _run_script(self) -> Path: tdef: NeMoRunTestDefinition = cast(NeMoRunTestDefinition, self.test_run.test.test_definition) From 68490ac87bf51bcda85162230f0243b5f115ed81 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Fri, 15 Aug 2025 10:54:42 -0700 Subject: [PATCH 26/28] use nemo userbuffers directly --- .../workloads/nemo_run/cloudai_nemorun.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/src/cloudai/workloads/nemo_run/cloudai_nemorun.py b/src/cloudai/workloads/nemo_run/cloudai_nemorun.py index 915953f8f..15841fe69 100644 --- a/src/cloudai/workloads/nemo_run/cloudai_nemorun.py +++ b/src/cloudai/workloads/nemo_run/cloudai_nemorun.py @@ -38,6 +38,10 @@ PipelineOverlapCfg, RingExchangeOverlapCfg, TransformerLayerTPOverlapCfg, + userbuffers_bf16_b200_h16384_tp4_cp2_mbs1_seqlen8192, + userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, + userbuffers_fp8_b200_h16384_tp4_cp2_mbs1_seqlen8192, + userbuffers_fp8_h100_h16384_tp8_cp2_mbs1_seqlen8192, ) from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer from nemo.lightning import AutoResume, NeMoLogger @@ -597,6 +601,39 @@ def llama3_70b_fp8_h100_tp_overlap_config() -> run.Config[TransformerLayerTPOver def get_tp_overlap_config(): gpu_type = os.getenv("CLOUDAI_GPU_TYPE") compute_dtype = os.getenv("CLOUDAI_GPU_DTYPE") + recipe_name = os.getenv("CLOUDAI_NEMO_RECIPE", "") + is_405b = "405b" in recipe_name.lower() + + # Use upstream userbuffer presets for Llama3.1 405B + if is_405b: + ub_cfg = { + "h100": { + "bf16": userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, + "fp8": userbuffers_fp8_h100_h16384_tp8_cp2_mbs1_seqlen8192, + }, + "b200": { + "bf16": userbuffers_bf16_b200_h16384_tp4_cp2_mbs1_seqlen8192, + "fp8": userbuffers_fp8_b200_h16384_tp4_cp2_mbs1_seqlen8192, + }, + "gb200": { + "bf16": userbuffers_bf16_b200_h16384_tp4_cp2_mbs1_seqlen8192, + "fp8": userbuffers_fp8_b200_h16384_tp4_cp2_mbs1_seqlen8192, + }, + } + fn = (ub_cfg.get(gpu_type, {}) or {}).get(compute_dtype) + if fn is not None: + tp_overlap_cfg = fn() + tp_comm_overlap = True + else: + print( + "Warning: 
Not using Default Comm Overlap Config.\n" + "Please set the GPU type and compute dtype in the environment variables." + ) + tp_overlap_cfg = None + tp_comm_overlap = False + return tp_overlap_cfg, tp_comm_overlap + + # Fallback: retain 70B-oriented configs for non-405B runs if gpu_type == "h100" and compute_dtype == "bf16": tp_overlap_cfg = llama3_70b_bf16_h100_tp_overlap_config() tp_comm_overlap = True From 0b126c014dd65633ed0f076ac6ec5cc66059e6c2 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Fri, 15 Aug 2025 11:15:53 -0700 Subject: [PATCH 27/28] 'TransformerLayerTPOverlapCfg' object is not callable fix --- src/cloudai/workloads/nemo_run/cloudai_nemorun.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/cloudai/workloads/nemo_run/cloudai_nemorun.py b/src/cloudai/workloads/nemo_run/cloudai_nemorun.py index 15841fe69..0f1222f47 100644 --- a/src/cloudai/workloads/nemo_run/cloudai_nemorun.py +++ b/src/cloudai/workloads/nemo_run/cloudai_nemorun.py @@ -620,9 +620,9 @@ def get_tp_overlap_config(): "fp8": userbuffers_fp8_b200_h16384_tp4_cp2_mbs1_seqlen8192, }, } - fn = (ub_cfg.get(gpu_type, {}) or {}).get(compute_dtype) - if fn is not None: - tp_overlap_cfg = fn() + cfg_or_factory = (ub_cfg.get(gpu_type, {}) or {}).get(compute_dtype) + if cfg_or_factory is not None: + tp_overlap_cfg = cfg_or_factory() if callable(cfg_or_factory) else cfg_or_factory tp_comm_overlap = True else: print( From b00ecc9f6e61184664d6ab56ede900d015018465 Mon Sep 17 00:00:00 2001 From: Malay Nagda Date: Mon, 18 Aug 2025 18:27:29 +0530 Subject: [PATCH 28/28] fsdp and ub cfgs Signed-off-by: Malay Nagda --- src/cloudai/workloads/nemo_run/cloudai_nemorun.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/cloudai/workloads/nemo_run/cloudai_nemorun.py b/src/cloudai/workloads/nemo_run/cloudai_nemorun.py index 0f1222f47..82140ec4b 100644 --- a/src/cloudai/workloads/nemo_run/cloudai_nemorun.py +++ b/src/cloudai/workloads/nemo_run/cloudai_nemorun.py @@ -622,7 +622,7 @@ def get_tp_overlap_config(): } cfg_or_factory = (ub_cfg.get(gpu_type, {}) or {}).get(compute_dtype) if cfg_or_factory is not None: - tp_overlap_cfg = cfg_or_factory() if callable(cfg_or_factory) else cfg_or_factory + tp_overlap_cfg = cfg_or_factory tp_comm_overlap = True else: print( @@ -922,7 +922,6 @@ def cloudai_llama3_405b_recipe() -> run.Partial: recipe.model.config.gradient_accumulation_fusion = False cb.defer_embedding_wgrad_compute = False - cb.wgrad_deferral_limit = 50 cb.overlap_param_gather_with_optimizer_step = False if disable_tp_commd_overlap: @@ -958,6 +957,16 @@ def cloudai_llama3_405b_recipe() -> run.Partial: if os.getenv("CLOUDAI_GPU_TYPE") in ["b200", "gb200"] and os.getenv("CLOUDAI_GPU_DTYPE") == "fp8": print("Info: use_precision_aware_optimizer is set to False for fp8 on b200/gb200 GPUs.") recipe.optim.config.use_precision_aware_optimizer = False + + recipe.trainer.callbacks[comm_overlap_callback_idx] = cb + + gpu_type = os.getenv("CLOUDAI_GPU_TYPE") + gpu_type = gpu_type.lower() if gpu_type else None + use_mcore_fsdp = bool(int(os.getenv("CLOUDAI_ENABLE_FSDP", "0"))) + + if use_mcore_fsdp and gpu_type == "gb200": + recipe.trainer.strategy.num_distributed_optimizer_instances = (recipe.trainer.num_nodes * 4) // 64 + return recipe