Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
bdd431c
support to expose agent config and seed hyperparameters
srivatsankrishnan Aug 7, 2025
dbdbcb0
remove random walker agent config
srivatsankrishnan Aug 7, 2025
9b80f39
debug logs
srivatsankrishnan Aug 7, 2025
b3f5c9c
debug log
srivatsankrishnan Aug 7, 2025
44176d1
debug logs
srivatsankrishnan Aug 7, 2025
8053f3c
more debug logs
srivatsankrishnan Aug 7, 2025
90e1fbf
more debug logs
srivatsankrishnan Aug 7, 2025
72385bd
remove debug logs
srivatsankrishnan Aug 7, 2025
6c72004
preserve agent_configs
srivatsankrishnan Aug 7, 2025
37f7e6b
more debugging
srivatsankrishnan Aug 7, 2025
7a099f9
debugging
srivatsankrishnan Aug 7, 2025
eddfec7
debug
srivatsankrishnan Aug 7, 2025
9c35c26
revert changes/debug logs
srivatsankrishnan Aug 7, 2025
dac116d
fix for param seeding
srivatsankrishnan Aug 8, 2025
2ae0261
more fixes
srivatsankrishnan Aug 8, 2025
0af7504
add bo agent specific fields
srivatsankrishnan Aug 8, 2025
b6cb985
debug messages
srivatsankrishnan Aug 8, 2025
2efd3bb
more debug logs
srivatsankrishnan Aug 8, 2025
d39784f
clean up and remove debug logging
srivatsankrishnan Aug 8, 2025
f59e20c
Calculate the num devices dynamically (GB200 case)
srivatsankrishnan Aug 12, 2025
620fa9f
typo (trainer.num_devices) --> trainer.devices
srivatsankrishnan Aug 12, 2025
7f00782
seperate reward value for failed constraint check
srivatsankrishnan Aug 12, 2025
262d1a1
405b refactor
srivatsankrishnan Aug 14, 2025
46c6ecc
remove the fdl
srivatsankrishnan Aug 14, 2025
1700f6b
remove tp_comm_overp disable with fsdp
srivatsankrishnan Aug 15, 2025
68490ac
use nemo userbuffers directly
srivatsankrishnan Aug 15, 2025
0b126c0
'TransformerLayerTPOverlapCfg' object is not callable fix
srivatsankrishnan Aug 15, 2025
b00ecc9
fsdp and ub cfgs
malay-nagda Aug 18, 2025
9121066
Merge branch 'main' into resolve_conflicts
srivatsankrishnan Aug 19, 2025
b7e9e61
Merge branch 'main' into agent_seeding
srivatsankrishnan Aug 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion src/cloudai/cli/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,29 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace):
)
continue

agent = agent_class(env)
agent_config = test_run.test.test_definition.agent_config

if agent_config:
agent_kwargs = {
k: v for k, v in agent_config.model_dump().items()
if v is not None and k not in ['extra_params', 'seed_parameters', 'agent_type']
}

if hasattr(agent_config, 'seed_parameters') and agent_config.seed_parameters:
action_space = env.define_action_space()
resolved_seeds = test_run.test.test_definition.resolve_seed_parameters(action_space)
if resolved_seeds:
agent_kwargs['seed_parameters'] = resolved_seeds

agent_kwargs.update(agent_config.extra_params)

try:
agent = agent_class(env, **agent_kwargs)
except TypeError as e:
logging.warning(f"Agent {agent_type} doesn't support some configuration parameters: {e}")
agent = agent_class(env)
else:
agent = agent_class(env)
for step in range(agent.max_steps):
result = agent.select_action()
if result is None:
Expand Down
2 changes: 1 addition & 1 deletion src/cloudai/configurator/cloudai_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:

if not self.test_run.test.test_definition.constraint_check(self.test_run):
logging.info("Constraint check failed. Skipping step.")
return [-1.0], -1.0, True, {}
return [-1.0], -1e-6, True, {}

logging.info(f"Running step {self.test_run.step} with action {action}")
new_tr = copy.deepcopy(self.test_run)
Expand Down
30 changes: 28 additions & 2 deletions src/cloudai/models/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Union

from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_serializer, field_validator, model_validator

from cloudai.core import CmdArgs, GitRepo, NsysConfiguration, Registry, Reporter, TestRun
from cloudai.models.workload import TestDefinition
from cloudai.models.workload import AgentConfig, BOAgentConfig, TestDefinition


def parse_reports_spec(
Expand Down Expand Up @@ -91,16 +91,42 @@ class TestRunModel(BaseModel):
agent: Optional[str] = None
agent_steps: Optional[int] = None
agent_metrics: list[str] = Field(default=["default"])
agent_config: Optional[Union[AgentConfig, BOAgentConfig]] = None

@field_validator('agent_config', mode='before')
@classmethod
def parse_agent_config(cls, v, info):
"""Parse agent_config based on the agent type."""

if v is None:
return None

if isinstance(v, AgentConfig):
return v

if isinstance(v, dict):
has_bo_fields = {'sobol_num_trials', 'botorch_num_trials', 'seed_parameters'} & v.keys()

is_bo_agent = v.get('agent_type') == 'bo_gp'

if has_bo_fields or is_bo_agent:
return BOAgentConfig.model_validate(v)
else:
return AgentConfig.model_validate(v)

return v

def tdef_model_dump(self, by_alias: bool) -> dict:
"""Return a dictionary with non-None values that correspond to the test definition fields."""
agent_config_dump = self.agent_config.model_dump() if self.agent_config else None
data = {
"name": self.name,
"description": self.description,
"test_template_name": self.test_template_name,
"agent": self.agent,
"agent_steps": self.agent_steps,
"agent_metrics": self.agent_metrics,
"agent_config": agent_config_dump,
"extra_container_mounts": self.extra_container_mounts,
"extra_env_vars": self.extra_env_vars if self.extra_env_vars else None,
"cmd_args": self.cmd_args.model_dump(by_alias=by_alias) if self.cmd_args else None,
Expand Down
110 changes: 109 additions & 1 deletion src/cloudai/models/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,55 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator

from cloudai.core import GitRepo, Installable, JobStatusResult, PythonExecutable, TestRun


class AgentConfig(BaseModel):
"""Base configuration class for agents used in DSE."""

model_config = ConfigDict(extra="forbid")

# Common agent parameters
random_seed: Optional[int] = None

# Allow for additional agent-specific parameters
extra_params: Dict[str, Any] = Field(default_factory=dict)


class BOAgentConfig(AgentConfig):
"""Configuration for Bayesian Optimization agent."""

# Add discriminator field to identify this as a BO config
agent_type: str = "bo_gp"

# BO-specific parameters
sobol_num_trials: Optional[int] = None
botorch_num_trials: Optional[int] = None

# Seed parameters for starting optimization from known configuration
seed_parameters: Optional[Dict[str, Any]] = None

# Allow for additional agent-specific parameters
extra_params: Dict[str, Any] = Field(default_factory=dict)

def __init__(self, **data):
# Ensure agent_type is always set even if not in input data
if 'agent_type' not in data:
data['agent_type'] = 'bo_gp'
super().__init__(**data)

def model_dump(self, **kwargs):
"""Override model_dump to ensure all BO fields are preserved."""
# Force exclude_none=False to preserve all fields
kwargs['exclude_none'] = False
result = super().model_dump(**kwargs)
# Ensure agent_type is always included to identify this as BO config
result['agent_type'] = self.agent_type
return result


class CmdArgs(BaseModel):
"""Test command arguments."""

Expand Down Expand Up @@ -107,6 +151,70 @@ class TestDefinition(BaseModel, ABC):
agent_steps: int = 1
agent_metrics: list[str] = Field(default=["default"])
agent_reward_function: str = "inverse"
agent_config: Optional[Union[AgentConfig, BOAgentConfig]] = None

@field_validator('agent_config', mode='before')
@classmethod
def parse_agent_config(cls, v, info):
"""Parse agent_config based on the agent type."""

if v is None:
return None

if isinstance(v, AgentConfig):
return v

if isinstance(v, dict):
# Check for BO-specific fields directly instead of relying on agent field
# since field validation order means agent might not be available yet
has_bo_fields = {'sobol_num_trials', 'botorch_num_trials', 'seed_parameters'} & v.keys()

# Also check for agent_type discriminator field
is_bo_agent = v.get('agent_type') == 'bo_gp'

if has_bo_fields or is_bo_agent:
# Use BOAgentConfig when BO-specific fields are present or agent_type indicates BO
return BOAgentConfig.model_validate(v)
else:
# Fall back to base AgentConfig for other cases
return AgentConfig.model_validate(v)

return v

def resolve_seed_parameters(self, action_space: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Resolve seed parameters by extracting values from the action space.

Args:
action_space: The flattened action space from cmd_args

Returns:
Resolved seed parameters with actual values
"""
if not self.agent_config or not hasattr(self.agent_config, 'seed_parameters'):
return None

seed_params = self.agent_config.seed_parameters
if not seed_params:
return None

resolved = {}
for param_name, value_spec in seed_params.items():
if param_name in action_space:
param_options = action_space[param_name]
if isinstance(param_options, list):
if value_spec in param_options:
resolved[param_name] = value_spec
elif isinstance(value_spec, int) and 0 <= value_spec < len(param_options):
resolved[param_name] = param_options[value_spec]
else:
resolved[param_name] = param_options[0]
else:
resolved[param_name] = param_options
else:
resolved[param_name] = value_spec

return resolved

@property
def cmd_args_dict(self) -> Dict[str, Union[str, List[str]]]:
Expand Down
12 changes: 6 additions & 6 deletions src/cloudai/test_scenario_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,15 +235,15 @@ def _prepare_tdef(self, test_info: TestRunModel) -> Tuple[Test, TestDefinition]:
if test_info.test_name not in self.test_mapping:
raise MissingTestError(test_info.test_name)
test = self.test_mapping[test_info.test_name]

test_defined = test.test_definition.model_dump(by_alias=True)
tc_defined = test_info.tdef_model_dump(by_alias=True)
test_defined = test.test_definition.model_dump(by_alias=True, exclude_none=False)
tc_defined = test_info.tdef_model_dump(by_alias=True, exclude_none=False)
merged_data = deep_merge(test_defined, tc_defined)

test.test_definition = tp.load_test_definition(merged_data, self.strict)
elif test_info.test_template_name: # test fully defined in the scenario
test = tp._parse_data(test_info.tdef_model_dump(by_alias=True), self.strict)
elif test_info.test_template_name:
test = tp._parse_data(test_info.tdef_model_dump(by_alias=True, exclude_none=False), self.strict)
else:
# this should never happen, because we check for this in the modelvalidator
raise ValueError(
f"Cannot configure test case '{test_info.id}' with both 'test_name' and 'test_template_name'."
)
Expand Down
Loading
Loading