diff --git a/tunix/cli/base_config.yaml b/tunix/cli/base_config.yaml
index d949e0bf..1a9b2641 100644
--- a/tunix/cli/base_config.yaml
+++ b/tunix/cli/base_config.yaml
@@ -12,6 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
+################################# LoRa #################################
+_lora_config: &base_lora_config
+  module_path: ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj"
+  rank: 16
+  alpha: 2.0
+  weight_qtype: "nf4"
+  tile_size: 256
+
+################################## MESH ##################################
+_mesh: &base_mesh
+  shape: "(2,2)"
+  # "('fsdp',)"
+  axis_names: "('fsdp','tp')"
+
+
 ################################## LOAD MODEL ##################################
 model_config: &base_model_config
   # Specify model name in style of {model_version}-{model size}, this will use to invoke model config
@@ -45,28 +61,26 @@ model_config: &base_model_config
   # Directory used for NNX conversion if downloaded Gemma/Gemma2 from Kaggle source.
   intermediate_ckpt_dir: "/tmp/intermediate_ckpt/"
 
-  ################################# LoRa #################################
-  lora_config:
-    module_path: ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj"
-    rank: 16
-    alpha: 2.0
-    weight_qtype: "nf4"
-    tile_size: 256
-
-  ################################## MESH ##################################
+  lora_config:
+    <<: *base_lora_config
   mesh:
-    shape: "(2,2)"
-    # "('fsdp',)"
-    axis_names: "('fsdp','tp')"
+    <<: *base_mesh
+
 
 actor_model_config:
-  <<: *base_model_config
+  lora_config:
+    <<: *base_lora_config
+  mesh:
+    <<: *base_mesh
 
 reference_model_config:
   <<: *base_model_config
 
 rollout_model_config:
-  <<: *base_model_config
+  mesh:
+    <<: *base_mesh
+
 
 ################################## Tokenizer ##################################
 tokenizer_config:
diff --git a/tunix/cli/config.py b/tunix/cli/config.py
index 11a6dfc7..0ed093b8 100644
--- a/tunix/cli/config.py
+++ b/tunix/cli/config.py
@@ -174,6 +174,25 @@ def __init__(self, argv: list[str], **kwargs):
         keys_from_env_and_command_line,
     )
     self.config = raw_keys
+
+    # If no model_config override is present, clone reference_model_config to model_config.
+    model_config_overwritten = any(
+        k == "model_config" or k.startswith("model_config.")
+        for k in keys_from_env_and_command_line
+    )
+    if (
+        not model_config_overwritten
+        and "reference_model_config" in self.config
+        and self.config.get("reference_model_config")
+    ):
+      logging.info(
+          "No model_config overrides detected, cloning reference_model_config"
+          " to model_config"
+      )
+      self.config["model_config"] = copy.deepcopy(
+          self.config["reference_model_config"]
+      )
+
     self._validate_tokenizer()
     self._validate_model_source(raw_keys)
     self.check_supported_workflow()
@@ -776,6 +795,11 @@ def _load_config_from_yaml(self, config_path: str):
       config_oconf = omegaconf.OmegaConf.load(config_path)
     except FileNotFoundError as e:
       raise ValueError(f"Config {config_path} not found.") from e
+    # Drop underscore-prefixed template keys (e.g. _lora_config, _mesh) so
+    # the YAML anchor templates never leak into the validated config.
+    for key in list(config_oconf.keys()):
+      if key.startswith("_"):
+        del config_oconf[key]
    return config_oconf
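
For reviewers unfamiliar with YAML merge keys: the `&base_lora_config` / `<<: *base_lora_config` pattern above is resolved by the YAML parser itself, so each section that merges an anchor ends up with its own plain nested mapping, and an explicit key next to a merge still wins. A minimal sketch of that behavior, using `yaml.safe_load` for illustration (the CLI actually loads through `omegaconf.OmegaConf.load`, which to my understanding resolves merge keys the same way; the `rank: 8` override below is invented for the example):

```python
import yaml

doc = """
_lora_config: &base_lora_config
  rank: 16
  alpha: 2.0

model_config:
  lora_config:
    <<: *base_lora_config

actor_model_config:
  lora_config:
    <<: *base_lora_config
    rank: 8  # an explicit key always wins over merged-in values
"""

cfg = yaml.safe_load(doc)
assert cfg["model_config"]["lora_config"]["rank"] == 16
assert cfg["actor_model_config"]["lora_config"]["rank"] == 8
# The template key itself survives parsing, which is why config.py now
# strips underscore-prefixed top-level keys after loading.
assert "_lora_config" in cfg
```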
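The underscore-stripping loop added to `_load_config_from_yaml` exists precisely because those anchor templates (`_lora_config`, `_mesh`) survive parsing as ordinary top-level keys. A small isolated sketch with made-up values; `list(...)` snapshots the keys so deleting during iteration is safe:

```python
from omegaconf import OmegaConf

# Illustrative stand-in for the loaded base_config.yaml.
config_oconf = OmegaConf.create({
    "_lora_config": {"rank": 16},   # anchor template, not a real section
    "_mesh": {"shape": "(2,2)"},
    "model_config": {"lora_config": {"rank": 16}},
})
for key in list(config_oconf.keys()):
  if key.startswith("_"):
    del config_oconf[key]
assert set(config_oconf.keys()) == {"model_config"}
```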
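The cloning branch in `__init__` can be exercised in isolation; the override list and config contents below are invented for illustration. The `copy.deepcopy` matters because a shallow copy would let a later mutation of `model_config` leak back into `reference_model_config`:

```python
import copy

config = {"reference_model_config": {"mesh": {"shape": "(2,2)"}}}
overrides = ["rollout_model_config.mesh.shape"]  # no model_config.* key

model_config_overwritten = any(
    k == "model_config" or k.startswith("model_config.") for k in overrides
)
if not model_config_overwritten and config.get("reference_model_config"):
  config["model_config"] = copy.deepcopy(config["reference_model_config"])

# Mutating the clone must not alias back into the reference model's settings.
config["model_config"]["mesh"]["shape"] = "(4,1)"
assert config["reference_model_config"]["mesh"]["shape"] == "(2,2)"
```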