diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py
index 85703c1b7..afd87362d 100644
--- a/src/cloudai/cli/handlers.py
+++ b/src/cloudai/cli/handlers.py
@@ -132,7 +132,29 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace):
             )
             continue
 
-        agent = agent_class(env)
+        agent_config = test_run.test.test_definition.agent_config
+
+        if agent_config:
+            agent_kwargs = {
+                k: v for k, v in agent_config.model_dump().items()
+                if v is not None and k not in {"extra_params", "seed_parameters", "agent_type"}
+            }
+
+            if getattr(agent_config, "seed_parameters", None):
+                action_space = env.define_action_space()
+                resolved_seeds = test_run.test.test_definition.resolve_seed_parameters(action_space)
+                if resolved_seeds:
+                    agent_kwargs["seed_parameters"] = resolved_seeds
+
+            agent_kwargs.update(agent_config.extra_params)
+
+            try:
+                agent = agent_class(env, **agent_kwargs)
+            except TypeError as e:
+                logging.warning(f"Agent {agent_type} does not accept some configured parameters ({e}); falling back to defaults.")
+                agent = agent_class(env)
+        else:
+            agent = agent_class(env)
         for step in range(agent.max_steps):
             result = agent.select_action()
             if result is None:
diff --git a/src/cloudai/configurator/cloudai_gym.py b/src/cloudai/configurator/cloudai_gym.py
index 915be08b7..ea2e4ca6a 100644
--- a/src/cloudai/configurator/cloudai_gym.py
+++ b/src/cloudai/configurator/cloudai_gym.py
@@ -102,7 +102,7 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
 
         if not self.test_run.test.test_definition.constraint_check(self.test_run):
             logging.info("Constraint check failed. Skipping step.")
-            return [-1.0], -1.0, True, {}
+            return [-1.0], -1e-6, True, {}
 
         logging.info(f"Running step {self.test_run.step} with action {action}")
         new_tr = copy.deepcopy(self.test_run)
diff --git a/src/cloudai/models/scenario.py b/src/cloudai/models/scenario.py
index 6a81e2583..3db943d42 100644
--- a/src/cloudai/models/scenario.py
+++ b/src/cloudai/models/scenario.py
@@ -17,12 +17,12 @@
 from __future__ import annotations
 
 from pathlib import Path
-from typing import Any, Literal, Optional
+from typing import Any, Literal, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_serializer, field_validator, model_validator
 
 from cloudai.core import CmdArgs, GitRepo, NsysConfiguration, Registry, Reporter, TestRun
-from cloudai.models.workload import TestDefinition
+from cloudai.models.workload import AgentConfig, BOAgentConfig, TestDefinition
 
 
 def parse_reports_spec(
@@ -91,9 +91,34 @@ class TestRunModel(BaseModel):
     agent: Optional[str] = None
     agent_steps: Optional[int] = None
     agent_metrics: list[str] = Field(default=["default"])
+    agent_config: Optional[Union[AgentConfig, BOAgentConfig]] = None
+
+    @field_validator("agent_config", mode="before")
+    @classmethod
+    def parse_agent_config(cls, v, info):
+        """Parse agent_config into AgentConfig or BOAgentConfig based on its fields."""
+        if v is None:
+            return None
+
+        if isinstance(v, AgentConfig):
+            return v
+
+        if isinstance(v, dict):
+            has_bo_fields = {"sobol_num_trials", "botorch_num_trials", "seed_parameters"} & v.keys()
+            is_bo_agent = v.get("agent_type") == "bo_gp"
+
+            if has_bo_fields or is_bo_agent:
+                return BOAgentConfig.model_validate(v)
+            return AgentConfig.model_validate(v)
+
+        return v
 
     def tdef_model_dump(self, by_alias: bool) -> dict:
         """Return a dictionary with non-None values that correspond to the test definition fields."""
+        agent_config_dump = self.agent_config.model_dump() if self.agent_config else None
         data = {
             "name": self.name,
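For intuition, a minimal sketch of the routing rule the before-mode validator above applies to raw dicts (the dict values are hypothetical scenario inputs; AgentConfig and BOAgentConfig come from cloudai.models.workload below):

    bo_markers = {"sobol_num_trials", "botorch_num_trials", "seed_parameters"}

    raw = {"sobol_num_trials": 16, "seed_parameters": {"tensor_model_parallel_size": 4}}
    assert bo_markers & raw.keys()                 # any BO marker key present...
    cfg = BOAgentConfig.model_validate(raw)        # ...routes the dict to BOAgentConfig

    plain = {"random_seed": 7}
    assert not (bo_markers & plain.keys())         # no markers and no agent_type == "bo_gp"...
    cfg = AgentConfig.model_validate(plain)        # ...falls back to the base AgentConfig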
"description": self.description, @@ -101,6 +126,7 @@ def tdef_model_dump(self, by_alias: bool) -> dict: "agent": self.agent, "agent_steps": self.agent_steps, "agent_metrics": self.agent_metrics, + "agent_config": agent_config_dump, "extra_container_mounts": self.extra_container_mounts, "extra_env_vars": self.extra_env_vars if self.extra_env_vars else None, "cmd_args": self.cmd_args.model_dump(by_alias=by_alias) if self.cmd_args else None, diff --git a/src/cloudai/models/workload.py b/src/cloudai/models/workload.py index 1745ae734..27f6cfbe4 100644 --- a/src/cloudai/models/workload.py +++ b/src/cloudai/models/workload.py @@ -18,11 +18,55 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator from cloudai.core import GitRepo, Installable, JobStatusResult, PythonExecutable, TestRun +class AgentConfig(BaseModel): + """Base configuration class for agents used in DSE.""" + + model_config = ConfigDict(extra="forbid") + + # Common agent parameters + random_seed: Optional[int] = None + + # Allow for additional agent-specific parameters + extra_params: Dict[str, Any] = Field(default_factory=dict) + + +class BOAgentConfig(AgentConfig): + """Configuration for Bayesian Optimization agent.""" + + # Add discriminator field to identify this as a BO config + agent_type: str = "bo_gp" + + # BO-specific parameters + sobol_num_trials: Optional[int] = None + botorch_num_trials: Optional[int] = None + + # Seed parameters for starting optimization from known configuration + seed_parameters: Optional[Dict[str, Any]] = None + + # Allow for additional agent-specific parameters + extra_params: Dict[str, Any] = Field(default_factory=dict) + + def __init__(self, **data): + # Ensure agent_type is always set even if not in input data + if 'agent_type' not in data: + data['agent_type'] = 'bo_gp' + super().__init__(**data) + + def model_dump(self, **kwargs): + """Override model_dump to ensure all BO fields are preserved.""" + # Force exclude_none=False to preserve all fields + kwargs['exclude_none'] = False + result = super().model_dump(**kwargs) + # Ensure agent_type is always included to identify this as BO config + result['agent_type'] = self.agent_type + return result + + class CmdArgs(BaseModel): """Test command arguments.""" @@ -107,6 +151,70 @@ class TestDefinition(BaseModel, ABC): agent_steps: int = 1 agent_metrics: list[str] = Field(default=["default"]) agent_reward_function: str = "inverse" + agent_config: Optional[Union[AgentConfig, BOAgentConfig]] = None + + @field_validator('agent_config', mode='before') + @classmethod + def parse_agent_config(cls, v, info): + """Parse agent_config based on the agent type.""" + + if v is None: + return None + + if isinstance(v, AgentConfig): + return v + + if isinstance(v, dict): + # Check for BO-specific fields directly instead of relying on agent field + # since field validation order means agent might not be available yet + has_bo_fields = {'sobol_num_trials', 'botorch_num_trials', 'seed_parameters'} & v.keys() + + # Also check for agent_type discriminator field + is_bo_agent = v.get('agent_type') == 'bo_gp' + + if has_bo_fields or is_bo_agent: + # Use BOAgentConfig when BO-specific fields are present or agent_type indicates BO + return BOAgentConfig.model_validate(v) + else: + # Fall back to base AgentConfig for other cases + return AgentConfig.model_validate(v) + + return v + + def resolve_seed_parameters(self, 
+        """
+        Resolve seed parameters by extracting values from the action space.
+
+        Args:
+            action_space: The flattened action space from cmd_args.
+
+        Returns:
+            Resolved seed parameters with actual values, or None if there is nothing to resolve.
+        """
+        seed_params = getattr(self.agent_config, "seed_parameters", None)
+        if not seed_params:
+            return None
+
+        resolved = {}
+        for param_name, value_spec in seed_params.items():
+            if param_name not in action_space:
+                # Not a tunable parameter; pass the seed value through untouched.
+                resolved[param_name] = value_spec
+                continue
+
+            param_options = action_space[param_name]
+            if not isinstance(param_options, list):
+                resolved[param_name] = param_options
+                continue
+
+            if value_spec in param_options:
+                # The seed value is itself a valid option.
+                resolved[param_name] = value_spec
+            elif isinstance(value_spec, int) and 0 <= value_spec < len(param_options):
+                # Treat the seed value as an index into the options.
+                resolved[param_name] = param_options[value_spec]
+            else:
+                # Unknown seed value; fall back to the first option.
+                resolved[param_name] = param_options[0]
+
+        return resolved
 
     @property
     def cmd_args_dict(self) -> Dict[str, Union[str, List[str]]]:
diff --git a/src/cloudai/test_scenario_parser.py b/src/cloudai/test_scenario_parser.py
index e39c1c449..6e0f37f92 100644
--- a/src/cloudai/test_scenario_parser.py
+++ b/src/cloudai/test_scenario_parser.py
@@ -235,15 +235,15 @@ def _prepare_tdef(self, test_info: TestRunModel) -> Tuple[Test, TestDefinition]:
         if test_info.test_name not in self.test_mapping:
             raise MissingTestError(test_info.test_name)
         test = self.test_mapping[test_info.test_name]
-
-        test_defined = test.test_definition.model_dump(by_alias=True)
-        tc_defined = test_info.tdef_model_dump(by_alias=True)
+
+        test_defined = test.test_definition.model_dump(by_alias=True, exclude_none=False)
+        tc_defined = test_info.tdef_model_dump(by_alias=True)
         merged_data = deep_merge(test_defined, tc_defined)
+
        test.test_definition = tp.load_test_definition(merged_data, self.strict)
-    elif test_info.test_template_name:  # test fully defined in the scenario
-        test = tp._parse_data(test_info.tdef_model_dump(by_alias=True), self.strict)
+    elif test_info.test_template_name:  # test fully defined in the scenario
+        test = tp._parse_data(test_info.tdef_model_dump(by_alias=True), self.strict)
     else:
-        # this should never happen, because we check for this in the modelvalidator
         raise ValueError(
             f"Cannot configure test case '{test_info.id}' with both 'test_name' and 'test_template_name'."
         )
diff --git a/src/cloudai/workloads/nemo_run/cloudai_nemorun.py b/src/cloudai/workloads/nemo_run/cloudai_nemorun.py
index db6109751..9502710f7 100644
--- a/src/cloudai/workloads/nemo_run/cloudai_nemorun.py
+++ b/src/cloudai/workloads/nemo_run/cloudai_nemorun.py
@@ -32,11 +32,16 @@
 from nemo.collections.llm.gpt.model.llama import Llama3Config70B, Llama31Config405B, LlamaModel
 from nemo.collections.llm.gpt.model.nemotron import Nemotron4Config15B, Nemotron4Config340B, NemotronModel
+from nemo.collections.llm.recipes.llama31_405b import pretrain_recipe as llama31_405b_pretrain_recipe
 from nemo.collections.llm.recipes.nemotron3_8b import pretrain_recipe as nemotron3_8b_recipe
 from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import (
     BulkOverlapCfg,
     PipelineOverlapCfg,
     RingExchangeOverlapCfg,
     TransformerLayerTPOverlapCfg,
+    userbuffers_bf16_b200_h16384_tp4_cp2_mbs1_seqlen8192,
+    userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192,
+    userbuffers_fp8_b200_h16384_tp4_cp2_mbs1_seqlen8192,
+    userbuffers_fp8_h100_h16384_tp8_cp2_mbs1_seqlen8192,
 )
 from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
 from nemo.lightning import AutoResume, NeMoLogger
@@ -596,6 +601,39 @@ def llama3_70b_fp8_h100_tp_overlap_config() -> run.Config[TransformerLayerTPOverlapCfg]:
 def get_tp_overlap_config():
     gpu_type = os.getenv("CLOUDAI_GPU_TYPE")
     compute_dtype = os.getenv("CLOUDAI_GPU_DTYPE")
+    recipe_name = os.getenv("CLOUDAI_NEMO_RECIPE", "")
+    is_405b = "405b" in recipe_name.lower()
+
+    # Use upstream userbuffer presets for Llama3.1 405B
+    if is_405b:
+        ub_cfg = {
+            "h100": {
+                "bf16": userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192,
+                "fp8": userbuffers_fp8_h100_h16384_tp8_cp2_mbs1_seqlen8192,
+            },
+            "b200": {
+                "bf16": userbuffers_bf16_b200_h16384_tp4_cp2_mbs1_seqlen8192,
+                "fp8": userbuffers_fp8_b200_h16384_tp4_cp2_mbs1_seqlen8192,
+            },
+            "gb200": {
+                "bf16": userbuffers_bf16_b200_h16384_tp4_cp2_mbs1_seqlen8192,
+                "fp8": userbuffers_fp8_b200_h16384_tp4_cp2_mbs1_seqlen8192,
+            },
+        }
+        cfg_or_factory = ub_cfg.get(gpu_type, {}).get(compute_dtype)
+        if cfg_or_factory is not None:
+            tp_overlap_cfg = cfg_or_factory
+            tp_comm_overlap = True
+        else:
+            print(
+                "Warning: not using a default comm overlap config.\n"
+                "Set CLOUDAI_GPU_TYPE and CLOUDAI_GPU_DTYPE to select a userbuffers preset."
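A sketch of how the env-driven lookup above behaves, assuming the module is imported and the variables are set before the call (values hypothetical):

    import os

    os.environ["CLOUDAI_NEMO_RECIPE"] = "cloudai_llama3_405b_recipe"
    os.environ["CLOUDAI_GPU_TYPE"] = "b200"
    os.environ["CLOUDAI_GPU_DTYPE"] = "fp8"

    cfg, overlap = get_tp_overlap_config()
    # cfg is userbuffers_fp8_b200_h16384_tp4_cp2_mbs1_seqlen8192, overlap is True;
    # an unknown type/dtype pair instead warns and returns (None, False)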
+            )
+            tp_overlap_cfg = None
+            tp_comm_overlap = False
+        return tp_overlap_cfg, tp_comm_overlap
+
+    # Fallback: retain 70B-oriented configs for non-405B runs
     if gpu_type == "h100" and compute_dtype == "bf16":
         tp_overlap_cfg = llama3_70b_bf16_h100_tp_overlap_config()
         tp_comm_overlap = True
@@ -618,6 +656,22 @@ def get_tp_overlap_config():
     return tp_overlap_cfg, tp_comm_overlap
 
 
+# Find the index of MegatronCommOverlapCallback in an existing callbacks list
+def get_comm_overlap_callback_idx(callbacks: list) -> Optional[int]:
+    for idx, cb in enumerate(callbacks):
+        fn_or_cls = getattr(cb, "__fn_or_cls__", None)
+        if fn_or_cls is MegatronCommOverlapCallback:
+            return idx
+        target = getattr(cb, "target", None) or getattr(cb, "cls", None)
+        if target is MegatronCommOverlapCallback:
+            return idx
+    return None
+
+
+# Pass-through hook: upstream userbuffer presets are consumed as-is; kept as an extension point for converting dataclass-based TP overlap cfgs to run.Config
+def to_run_config(obj):
+    return obj
+
+
 def set_perf_optimization_configs(recipe):
     recipe.model.config.cross_entropy_fusion_impl = "te"
@@ -630,6 +684,7 @@ def set_perf_optimization_configs(recipe):
     return recipe
 
 
+
 # LLAMA3 8B Recipe
 @run.cli.factory(target=llm.pretrain)
 def cloudai_llama3_8b_recipe() -> run.Partial:
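A usage sketch for the finder above (the callback list is hypothetical; it relies on nemo_run Config objects exposing the wrapped class via __fn_or_cls__, the same attribute the finder checks):

    callbacks = [
        run.Config(GarbageCollectionCallback, gc_interval_train=100, gc_interval_val=100),
        run.Config(MegatronCommOverlapCallback, tp_comm_overlap=False),
    ]
    idx = get_comm_overlap_callback_idx(callbacks)  # -> 1, matched via __fn_or_cls__
    assert idx is not None and callbacks[idx].__fn_or_cls__ is MegatronCommOverlapCallback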
@@ -771,104 +826,27 @@ def cloudai_llama3_70b_recipe() -> run.Partial:
 # LLAMA3 405B Recipe
 @run.cli.factory(target=llm.pretrain)
 def cloudai_llama3_405b_recipe() -> run.Partial:
-    recipe = run.Partial(
-        llm.pretrain,
-        model=run.Config(LlamaModel, config=Llama31Config405B()),
-        data=run.Config(
-            MockDataModule,
-            seq_length=8192,
-            micro_batch_size=1,
-            global_batch_size=8,
-            tokenizer=null_tokenizer(vocab_size=128256),
-        ),
-        trainer=run.Config(
-            nl.Trainer,
-            devices=8,
-            num_nodes=1,
-            accelerator="gpu",
-            max_steps=10,
-            limit_test_batches=50,
-            limit_val_batches=32,
-            log_every_n_steps=10,
-            accumulate_grad_batches=1,
-            plugins=run.Config(
-                nl.MegatronMixedPrecision,
-                precision="bf16-mixed",
-                params_dtype=torch.bfloat16,
-                pipeline_dtype=torch.bfloat16,
-                autocast_enabled=False,
-                grad_reduce_in_fp32=False,
-            ),
-            strategy=run.Config(
-                nl.MegatronStrategy,
-                tensor_model_parallel_size=8,
-                pipeline_model_parallel_size=1,
-                context_parallel_size=2,
-                virtual_pipeline_model_parallel_size=8,
-                sequence_parallel=True,
-                expert_model_parallel_size=1,
-                expert_tensor_parallel_size=None,
-                pipeline_dtype=torch.bfloat16,
-                gradient_as_bucket_view=True,
-                ckpt_async_save=True,
-                ckpt_parallel_load=True,
-                ddp=run.Config(
-                    DistributedDataParallelConfig,
-                    check_for_nan_in_grad=False,
-                    grad_reduce_in_fp32=True,
-                    overlap_grad_reduce=True,
-                    overlap_param_gather=True,
-                ),
-            ),
-            num_sanity_val_steps=0,
-            use_distributed_sampler=False,
-            val_check_interval=1000,
-            max_epochs=10,
-            callbacks=[
-                timing_callback(),
-            ],
-        ),
-        optim=run.Config(
-            nl.MegatronOptimizerModule,
-            config=run.Config(
-                OptimizerConfig,
-                lr=0.0003,
-                bf16=True,
-                use_precision_aware_optimizer=True,
-                params_dtype=torch.bfloat16,
-                use_distributed_optimizer=True,
-                weight_decay=0.1,
-                adam_beta1=0.9,
-                adam_beta2=0.95,
-                adam_eps=1e-05,
-                clip_grad=1.0,
-            ),
-            lr_scheduler=run.Config(
-                CosineAnnealingScheduler,
-                warmup_steps=2000,
-                constant_steps=0,
-                min_lr=2.9999999999999997e-05,
-            ),
-        ),
-        resume=run.Config(
-            nl.AutoResume,
-            resume_if_exists=True,
-            resume_ignore_no_checkpoint=True,
-            resume_past_end=True,
-        ),
-    )
+    recipe = llama31_405b_pretrain_recipe(performance_mode=True)
+
+    # Optional tokenizer override for CloudAI runs
+    recipe.data.tokenizer = null_tokenizer(vocab_size=128256)
+
     recipe.model.config.expert_tensor_parallel_size = None
     recipe.model.config.seq_length = 8192
 
     tp_overlap_cfg, tp_comm_overlap = get_tp_overlap_config()
-    megatron_comm_overlap_callback = run.Config(
-        MegatronCommOverlapCallback,
-        tp_comm_overlap=tp_comm_overlap,
-        tp_comm_overlap_cfg=tp_overlap_cfg,
-        overlap_param_gather_with_optimizer_step=True,
-        defer_embedding_wgrad_compute=True,
-        wgrad_deferral_limit=50,
-    )
+
+    # Locate and update the existing comm-overlap callback to match our env
+    comm_overlap_callback_idx = get_comm_overlap_callback_idx(recipe.trainer.callbacks)
+    assert comm_overlap_callback_idx is not None, "MegatronCommOverlapCallback missing. Required for performance."
+
+    tp_comm_overlap_cfg = to_run_config(tp_overlap_cfg)
+    cb = recipe.trainer.callbacks[comm_overlap_callback_idx]
+    cb.tp_comm_overlap = tp_comm_overlap
+    cb.tp_comm_overlap_cfg = tp_comm_overlap_cfg
+    cb.overlap_param_gather_with_optimizer_step = True
+    cb.defer_embedding_wgrad_compute = True
+    cb.wgrad_deferral_limit = 50
 
     enable_fsdp = os.getenv("CLOUDAI_ENABLE_FSDP", "0") == "1"
     disable_tp_commd_overlap = os.getenv("CLOUDAI_DISABLE_TP_COMM_OVERLAP", "0") == "1"
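The in-place update above relies on nemo_run Config objects accepting attribute assignment for parameters of the wrapped callable, the same pattern the recipe uses on cb; a minimal sketch:

    cfg = run.Config(MegatronCommOverlapCallback, tp_comm_overlap=False)
    cfg.tp_comm_overlap = True       # reassign an argument given at construction
    cfg.wgrad_deferral_limit = 50    # set an argument not passed at construction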
@@ -880,12 +858,12 @@ def cloudai_llama3_405b_recipe() -> run.Partial:
         recipe.trainer.strategy.ddp.average_in_collective = False
         recipe.trainer.strategy.ddp.keep_fp8_transpose_cache_when_using_custom_fsdp = False
         recipe.model.config.gradient_accumulation_fusion = False
-        megatron_comm_overlap_callback.defer_embedding_wgrad_compute = False
-        megatron_comm_overlap_callback.wgrad_deferral_limit = 50
-        megatron_comm_overlap_callback.overlap_param_gather_with_optimizer_step = False
+
+        cb.defer_embedding_wgrad_compute = False
+        cb.overlap_param_gather_with_optimizer_step = False
 
     if disable_tp_commd_overlap:
-        megatron_comm_overlap_callback.tp_comm_overlap = False
+        cb.tp_comm_overlap = False
 
     recompute_layers = int(os.getenv("CLOUDAI_RECOMPUTE_LAYERS", "0"))
     if recompute_layers > 0:
@@ -907,7 +885,6 @@ def cloudai_llama3_405b_recipe() -> run.Partial:
             model_name="llama3",
         )
     )
-    recipe.trainer.callbacks.append(megatron_comm_overlap_callback)
     recipe.trainer.callbacks.append(run.Config(GarbageCollectionCallback, gc_interval_train=100, gc_interval_val=100))
     recipe.trainer.strategy.account_for_embedding_in_pipeline_split = True
     recipe.trainer.strategy.account_for_loss_in_pipeline_split = True
@@ -918,6 +895,16 @@ def cloudai_llama3_405b_recipe() -> run.Partial:
     if os.getenv("CLOUDAI_GPU_TYPE") in ["b200", "gb200"] and os.getenv("CLOUDAI_GPU_DTYPE") == "fp8":
         print("Info: use_precision_aware_optimizer is set to False for fp8 on b200/gb200 GPUs.")
         recipe.optim.config.use_precision_aware_optimizer = False
+
+    gpu_type = os.getenv("CLOUDAI_GPU_TYPE")
+    gpu_type = gpu_type.lower() if gpu_type else None
+    use_mcore_fsdp = bool(int(os.getenv("CLOUDAI_ENABLE_FSDP", "0")))
+
+    if use_mcore_fsdp and gpu_type == "gb200":
+        # One optimizer instance per 64 GPUs at 4 GPUs per node; never drop below 1
+        recipe.trainer.strategy.num_distributed_optimizer_instances = max(1, (recipe.trainer.num_nodes * 4) // 64)
+
     return recipe
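A quick check of the GB200 sizing above (4 GPUs per node, one distributed-optimizer instance per 64 GPUs; the max(1, ...) clamp keeps small node counts valid):

    for nodes in (8, 16, 32, 128):
        print(nodes, max(1, (nodes * 4) // 64))
    # 8 -> 1 (clamped), 16 -> 1, 32 -> 2, 128 -> 8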
diff --git a/src/cloudai/workloads/nemo_run/nemo_run.py b/src/cloudai/workloads/nemo_run/nemo_run.py
index b0094ea9b..8be9a7b38 100644
--- a/src/cloudai/workloads/nemo_run/nemo_run.py
+++ b/src/cloudai/workloads/nemo_run/nemo_run.py
@@ -82,6 +82,7 @@ class Trainer(BaseModel):
     max_steps: Union[int, List[int]] = 100
     val_check_interval: Union[int, float, list[Union[int, float]]] = 1000
     num_nodes: Optional[int] = None  # sweeps are done via TestRun.num_nodes
+    devices: Optional[int] = None  # GPUs per node; constraint_check assumes 8 when unset
     strategy: TrainerStrategy = Field(default_factory=TrainerStrategy)
     plugins: Optional[Plugin] = None
     callbacks: Optional[Union[str, list[str]]] = None
@@ -150,7 +151,7 @@ def constraint_check(self, tr: TestRun) -> bool:
         pp = cast(int, self.cmd_args.trainer.strategy.pipeline_model_parallel_size)
         cp = cast(int, self.cmd_args.trainer.strategy.context_parallel_size)
         vp = cast(Optional[int], self.cmd_args.trainer.strategy.virtual_pipeline_model_parallel_size)
-        num_gpus = tr.nnodes * 8
+        num_gpus = tr.nnodes * (self.cmd_args.trainer.devices or 8)
         num_layers = cast(int, self.cmd_args.num_layers)
         dp = num_gpus // (tp * pp * cp)
         mbs = cast(int, self.cmd_args.data.micro_batch_size)
diff --git a/src/cloudai/workloads/nemo_run/slurm_command_gen_strategy.py b/src/cloudai/workloads/nemo_run/slurm_command_gen_strategy.py
index f03fc25d7..748f8f5bd 100644
--- a/src/cloudai/workloads/nemo_run/slurm_command_gen_strategy.py
+++ b/src/cloudai/workloads/nemo_run/slurm_command_gen_strategy.py
@@ -49,15 +49,6 @@ def _set_additional_env_vars(self, tdef: NeMoRunTestDefinition):
             logging.debug("Setting NCCL_P2P_NET_CHUNKSIZE to 2097152 as pipeline_model_parallel_size is greater than 1")
             self.final_env_vars["NCCL_P2P_NET_CHUNKSIZE"] = "2097152"
 
-        enable_fsdp = self.final_env_vars.get("CLOUDAI_ENABLE_FSDP", "0")
-        if enable_fsdp == "1":
-            logging.info(
-                (
-                    "CLOUDAI_ENABLE_FSDP is set to 1. Currently, NemoRun does not support FSDP "
-                    "with TP communication overlap."
-                )
-            )
-            self.final_env_vars["CLOUDAI_DISABLE_TP_COMM_OVERLAP"] = "1"
 
     def _run_script(self) -> Path:
         tdef: NeMoRunTestDefinition = cast(NeMoRunTestDefinition, self.test_run.test.test_definition)
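A worked example of the updated GPU count in constraint_check (numbers hypothetical, e.g. a 4-GPU-per-node system):

    nnodes, devices = 4, 4
    num_gpus = nnodes * (devices or 8)   # 16, instead of the old hard-coded 4 * 8 = 32
    tp, pp, cp = 4, 2, 1
    dp = num_gpus // (tp * pp * cp)      # -> 2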