Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions tunix/cli/base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.


################################# LoRa #################################
_lora_config: &base_lora_config
module_path: ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj"
rank: 16
alpha: 2.0
weight_qtype: "nf4"
tile_size: 256

################################## MESH ##################################
_mesh: &base_mesh
shape: "(2,2)"
# "('fsdp',)"
axis_names: "('fsdp','tp')"


################################## LOAD MODEL ##################################
model_config: &base_model_config
# Specify model name in style of {model_version}-{model size}, this will use to invoke model config
Expand Down Expand Up @@ -45,28 +61,26 @@ model_config: &base_model_config

# Directory used for NNX conversion if downloaded Gemma/Gemma2 from Kaggle source.
intermediate_ckpt_dir: "/tmp/intermediate_ckpt/"
################################# LoRa #################################
lora_config:
module_path: ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj"
rank: 16
alpha: 2.0
weight_qtype: "nf4"
tile_size: 256

################################## MESH ##################################
lora_config:
<<: *base_lora_config
mesh:
shape: "(2,2)"
# "('fsdp',)"
axis_names: "('fsdp','tp')"
<<: *base_mesh


actor_model_config:
<<: *base_model_config
lora_config:
<<: *base_lora_config
mesh:
<<: *base_mesh

reference_model_config:
<<: *base_model_config

rollout_model_config:
<<: *base_model_config
mesh:
<<: *base_mesh


################################## Tokenizer ##################################
tokenizer_config:
Expand Down
22 changes: 22 additions & 0 deletions tunix/cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,25 @@ def __init__(self, argv: list[str], **kwargs):
keys_from_env_and_command_line,
)
self.config = raw_keys

# If no model_config related thing is overwritten, clone reference_model_config to model_config
model_config_overwritten = any(
k == "model_config" or k.startswith("model_config.")
for k in keys_from_env_and_command_line
)
if (
not model_config_overwritten
and "reference_model_config" in self.config
and self.config.get("reference_model_config")
):
logging.info(
"No model_config overrides detected, cloning reference_model_config"
" to model_config"
)
self.config["model_config"] = copy.deepcopy(
self.config["reference_model_config"]
)

self._validate_tokenizer()
self._validate_model_source(raw_keys)
self.check_supported_workflow()
Expand Down Expand Up @@ -776,6 +795,9 @@ def _load_config_from_yaml(self, config_path: str):
config_oconf = omegaconf.OmegaConf.load(config_path)
except FileNotFoundError as e:
raise ValueError(f"Config {config_path} not found.") from e
for key in list(config_oconf.keys()):
if key.startswith("_"):
del config_oconf[key]

return config_oconf

Expand Down