
Make the specified config parameters update the pretrained config #211


Merged
merged 32 commits into main from update_pretrained_config
Apr 17, 2025

Conversation

jlamypoirier
Collaborator

@jlamypoirier jlamypoirier commented Mar 27, 2025

✨ Description

Fixes #170

The specified model config now overrides the parameters in the pretrained config, instead of the load_config field determining whether they are used or discarded. This allows loading the full config from the pretrained model while also overriding some of the fields.

This is a big step towards #166, with some caveats:

  • The conversion mechanism still converts only architecture parameters
  • load_config remains architecture by default. It needs to be set explicitly to model or fast_llm for the full config to be loaded. Changing the default would be difficult because of backward compatibility issues.

This is technically a breaking change, in the sense that fields that were previously ignored may now change behaviour, but defining them was a bad idea in the first place and I don't expect any problems in practice.

🔍 Type of change

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

@jlamypoirier jlamypoirier changed the base branch from config_updates to main March 27, 2025 01:26
@jlamypoirier jlamypoirier changed the base branch from main to config_updates March 27, 2025 01:26
@jlamypoirier jlamypoirier marked this pull request as ready for review April 5, 2025 01:48
else:
expected_config["base_model"] = base_model_update

check_equal_nested(serialized_config, expected_config)
Collaborator

@tscholak tscholak Apr 10, 2025
So I want to make sure I understand what this test validates. This is difficult because I don't understand what the different load-config values mean, and I suspect I'm not alone there.
As far as I can tell, we want to make sure that when we initialize a model from a saved config and some ad hoc updates, and also somehow specify how much of the saved config to load, then the final config reflects that choice correctly and precisely.
So, now drilling into the different load_config cases:

  • architecture: is this supposed to ignore everything else, like training parameters?
  • model: I guess this is architecture++, but may not include everything
  • fast_llm: pulls in everything?
  • none: ignore everything

Can you please explain all this? Otherwise I can't understand what the expected behaviour here is and if that all makes sense. Thanks

Collaborator Author

We're making a FastLLMModelConfig from a loaded one. Remember its structure:

class FastLLMModelConfig(Config):
    base_model: BaseModelConfig 
    multi_stage: MultiStageConfig
    distributed: DistributedConfig

Where BaseModelConfig subclasses BaseModelArchitectureConfig, which defines the core model parameters, i.e. those that it doesn't make sense to override on an existing model.

So:

  • fast_llm: load the whole thing and use it as the default.
  • none: load nothing and use Fast-LLM's defaults.
  • model: load the base model config, but use Fast-LLM defaults for multi_stage and distributed.
  • architecture: load the architecture part of the base model config, but use Fast-LLM defaults for the rest.

Note that training parameters outside the model config are irrelevant, they aren't even in the checkpoint. Also for Hugging Face there is no multi_stage or distributed either, so the Fast-LLM defaults are used either way (same for all non-architecture parameters at the moment).
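The four modes can be pictured as plain dict filtering. Here is a rough sketch; the names (`select_loaded_fields`, `ARCHITECTURE_FIELDS`) and the field set are illustrative, not the actual Fast-LLM API:

```python
# Illustrative subset of "architecture" fields; the real set lives in
# BaseModelArchitectureConfig, not here.
ARCHITECTURE_FIELDS = {"hidden_size", "num_layers"}

def select_loaded_fields(pretrained: dict, load_config: str) -> dict:
    """Return the part of the pretrained config that each mode keeps."""
    if load_config == "none":
        return {}  # use Fast-LLM defaults everywhere
    if load_config == "architecture":
        base = pretrained.get("base_model", {})
        # keep only architecture parameters of the base model config
        return {"base_model": {k: v for k, v in base.items() if k in ARCHITECTURE_FIELDS}}
    if load_config == "model":
        # whole base model config, but Fast-LLM defaults for the rest
        return {"base_model": dict(pretrained.get("base_model", {}))}
    if load_config == "fast_llm":
        return dict(pretrained)  # everything, incl. multi_stage and distributed
    raise ValueError(f"Unknown load_config: {load_config}")
```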

Collaborator

Thanks @jlamypoirier!

This framing confirms the problem: we are encoding internal class structure into the external loading API. I don't think that's a good idea. I don't think anybody cares to know what BaseModelArchitectureConfig is. We should just be able to say: "load the full thing, and let me override some fields."

Collaborator

@tscholak tscholak left a comment

Thanks for the contribution. I'm a bit at a loss here. Do we actually care what part of the config to load from disk? And can't we just say that overrides should always take precedence? I hope they do, because that is what I would expect.
You're saying there's no breaking change here, but there is, even if it's just a subtle one. Actually I don't mind breaking changes, and I think what's done here now is not going far enough.
So can we be more bold and just kill load_config? Let's always load full configs, and then apply an override if it exists. Very breaking change, but also much clearer behaviour.

@tscholak
Collaborator

tscholak commented Apr 10, 2025

Maybe I'm way off here, but instead of treating config loading almost like an interpreter pass, can we treat it like a pure transformation pipeline:

defaults (from config class definitions)
   ↓
+ optional HF import (same schema, or mapped into Fast-LLM's)
   ↓
+ user-provided config and command-line overrides (flat or nested dict, no side effects)
   ↓
→ resolve (compute derived fields, check invariants)

Each step produces a complete dict, no state or mutation.
Overrides are clean: a flat or nested dict.
Derived fields can be recomputed anytime, and are never persisted or overridden.

Also, we can easily track here where each field value came from (default, import, override)

@jlamypoirier
Collaborator Author

Thanks for the contribution. I'm a bit at a loss here. Do we actually care what part of the config to load from disk? And can't we just say that overrides should always take precedence? I hope they do, because that is what I would expect. You're saying there's no breaking change here, but there is, even if it's just a subtle one. Actually I don't mind breaking changes, and I think what's done here now is not going far enough. So can we be more bold and just kill load_config? Let's always load full configs, and then apply an override if it exists. Very breaking change, but also much clearer behaviour.

We cared a lot more before this PR. The architecture/base_model split was there in large part to make it possible to override some base model parameters when loading a pretrained model, but that's no longer relevant.

So there isn't much reason for using load_config=architecture anymore; backward compatibility forced me to keep it as the default for now, but we should look into changing it at some point. fast_llm and none are nice to have but not essential, so maybe we could just drop the option and keep model as the default and only choice.

Maybe I'm way off here, but instead of treating config loading almost like an interpreter pass, can we treat it like a pure transformation pipeline:

defaults (from config class definitions)
   ↓
+ optional HF import (same schema, or mapped into Fast-LLM's)
   ↓
+ user-provided config and command-line overrides (flat or nested dict, no side effects)
   ↓
→ resolve (compute derived fields, check invariants)

Each step produces a complete dict, no state or mutation. Overrides are clean: a flat or nested dict. Derived fields can be recomputed anytime, and are never persisted or overridden.

Also, we can easily track here where each field value came from (default, import, override)

This is approximately what I'm trying to do here, but remembering the actual transformation pipeline would be a bad idea:

  • This picture is incomplete as there are already many variations to this pipeline (external overrides within Fast-LLM, other optional imports, etc.).
  • Reproducing a derived value would be very difficult and require basically re-running the whole thing.
  • Reproducing a part of the config (needed in a lot of cases, ex. the model config) would mean either remembering the entire context from which it was generated (ex. the whole training config) or keeping track of a separate pipeline for reproducing that config part, and either option sounds really bad. This is the main problem I'm trying to solve here.

Since we only care about reproducing the config or any of its sub-configs, and not about where parameters came from, the natural alternative is to compress the transformation pipeline into a smaller piece of information that is sufficient to do so. We were already somewhat doing that by keeping only non-default values, but here (in #205) I'm proposing to keep track of explicitly set parameters instead, which I think is a big improvement:

  • It reduces the amount of clutter in serialized configs by removing implicit defaults, so configs are more readable.
  • It keeps values explicitly set to the default in saved configs, which reflects the user's intention to put the emphasis on them.
  • It allows using any config as a config update, which drops the need for additional structure, ad-hoc design choices and hacks, etc. Things just work.
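A toy sketch of the explicitly-set-parameters idea; the semantics are my reading of #205 and the helper names (`from_user_dict`, `serialize_explicit`) are made up for illustration:

```python
from dataclasses import dataclass, fields

@dataclass
class BatchConfig:
    size: int = 1
    sequence_length: int = 2048

def from_user_dict(cls, user: dict):
    """Instantiate a config while remembering which fields the user set."""
    obj = cls(**user)
    obj._explicit = set(user)  # hypothetical bookkeeping attribute
    return obj

def serialize_explicit(obj) -> dict:
    """Serialize only explicitly set fields, even ones equal to the default."""
    return {f.name: getattr(obj, f.name) for f in fields(obj) if f.name in obj._explicit}
```

With this, `from_user_dict(BatchConfig, {"size": 1})` serializes to `{"size": 1}` even though 1 is the default, while the untouched `sequence_length` is omitted, and the serialized dict can itself be used as a config update.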

@tscholak
Collaborator

thanks for the great response, @jlamypoirier!

So there isn't much reason for using load_config=architecture anymore; backward compatibility forced me to keep it as the default for now, but we should look into changing it at some point. fast_llm and none are nice to have but not essential, so maybe we could just drop the option and keep model as the default and only choice.

If we agree that architecture is obsolete and fast_llm and none aren't essential, then we should just remove load_config and always do full config loading followed by user overrides. That's easy to reason about and what people expect.
I think that trying to preserve half-broken options for backward compatibility doesn't make the system better. It just makes it more confusing. Let's be bold and simplify here!

@tscholak
Collaborator

Remembering the transformation pipeline would be a bad idea [...] Reproducing a part of the config [...] would require remembering the entire context [...] which sounds really bad.

It is bad, and I'm not arguing for that. I'm actually arguing for the opposite. All dependencies on config should be downstream from config resolution. Reproducing a subconfig should never require knowledge of global context. And that is ensured by treating config resolution as a pure, layered transformation (i.e. deterministic: same input leads to same output, side-effect free: no in-place mutation during validation or derivation, and context-independent: doesn't depend on being invoked from some special place.) That's the core idea I'm arguing for: make config resolution a clean, deterministic pipeline. After that, the resolved config is immutable and complete. No special knowledge or context needed. You can pass any sub-tree around with confidence.

Right now, the system interleaves validation, mutation, and derivation. It's hard to tell what happened when or why. Again, we should treat config resolution explicitly as:

defaults (class definitions)
→ optional import (e.g. Hugging Face)
→ user overrides (first YAML, then command line)
→ compute derived fields (no side effects)
→ final resolved config (immutable)

Each layer is self-contained. No mutation. No inference of user intent. No implicit context dependencies. Once resolved, the config is just immutable data.

You mentioned tracking explicitly set fields. Sure, that's useful for cleaning up serialized configs. (And maybe this feature is overrated, because people only look at configs in Wandb and there the config is complete and things are manageable and easily discoverable due to search.) But cleaning up serialized configs doesn't address the real issue: users and developers need to understand how a config was built, not just what fields were manually set. Debugging is about causality, not just surface state.

We only care about reproducing the config or any of its sub-configs [...]

That only holds if you trust the system to always get it right, with no surprises. Right now though we have confusion around defaults, derivations, and implicit mutations. A clean layering model eliminates that ambiguity and makes things easier to test, reason about, and serialize correctly.

I'm not saying we need provenance tracking or transformation logs, just clear separation between inputs, transformation, and outputs. Derived fields shouldn't be serialized. They should be recomputed from inputs every time, just like computed properties or @computed_field in Pydantic.
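For illustration, one way a derived field can be recomputed from its inputs every time and kept out of serialization, using a plain property on a frozen dataclass (a sketch, not Fast-LLM code):

```python
from dataclasses import dataclass, asdict

@dataclass(frozen=True)
class ModelConfig:
    hidden_size: int = 1024
    num_heads: int = 16

    @property
    def head_dim(self) -> int:
        # derived, never serialized: asdict() only sees real fields
        return self.hidden_size // self.num_heads
```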

Clarity and compactness are not at odds. But compactness without clarity leads to brittle complexity. I think we are at that point now. Let's treat it as an opportunity to simplify.

@jlamypoirier
Collaborator Author

To answer your comments and those in #205, I think we want the same thing, we just have a different view on how to get there.

defaults (class definitions)
→ optional import (e.g. Hugging Face)
→ user overrides (first YAML, then command line)
→ compute derived fields (no side effects)
→ final resolved config (immutable)

The first steps make sense from the Runnable point of view, i.e. the final program. We'll come back to that, but for now let's focus on the last two steps which form the core of the config system.

Core config system

The "compute derived fields" and "final resolution" parts are a lot more complicated than they look and hide quite a few things which any good config system needs to address (here roughly in the order it's done in Fast-LLM):

  • Instantiate the config object from the config dict.
  • Instantiate child configs (recursively)
  • Set value dependencies between configs, ex. let the batch config know about the data-parallel size so it can pick a sensible default. We want to avoid this as much as possible, but often it's the best course of action (ex. it makes configs less verbose, reduces the risk of error, and allows changing the world size without re-calculating the number of sequential micro-batches manually). Omegaconf solves this problem with interpolations and resolvers, which works and has the benefit of being accessible from the yaml config, but is extremely complex and error-prone, and is unaware of the underlying config structure, so I would prefer a simpler and safer option.
  • Compute derived (aka implicit) defaults. This is different from pure derived values because they can still be overridden; we just provide sensible defaults. Ex. set the mlp hidden size to 4 times the hidden size. There are plenty of known ways to do it, with __post_init__, omegaconf resolvers/interpolations, etc., so nothing new here.
  • Compute actual derived fields. These are dataclass fields with init=False, we're not doing anything new here. These are never serialized (at least in Fast-LLM, not sure about others).
  • Resolve child configs recursively, i.e. do all of these steps (except instantiation) on child configs. This is the one place Fast-LLM diverges noticeably from the alternatives. With dataclasses, pydantic and omegaconf, all but one of these steps are done in __post_init__. This means config instantiation is tied to its final resolution, so all cross-config dependencies must be resolved before instantiation. Fast-LLM unties the two by replacing __post_init__ with validate/_validate, which leaves an opportunity to set cross-config dependencies in-between. validate is still called at instantiation by default, so it's an opt-in feature.
  • Standardize and verify value types, ex. convert ints to floats, strings to enums, etc. This is quite important for convenience and to support complex types. It's relatively standard (omegaconf, pydantic), though Fast-LLM is a bit more strict to stay on the safe side.
  • Enforce in-config constraints. Again a standard with __post_init__, though I also added validators directly in the field definition for clarity (like in pydantic) with limited success.
  • Enforce cross-config constraints. The easy way to do this is in a parent config, and __post_init__ works fine for it, so not much to discuss there. Another scenario to consider is when the dependency is between "cousin" configs independently of the parent, ex. BatchConfig depends on DistributedConfig, whether it's within a trainer, a hugging face wrapper, a non-wrapped inference runner, a reference model within a trainer, etc. Here validating in the parent class depends on the parent's good faith, so it is error-prone and makes the class usage more complicated. Fast-LLM solves this kind of issue by doing the check in the class that needs it, and forcing the parent class to help it do so. This is one less thing to worry about: the user can just ignore it and will be told if they're doing something wrong.
  • Freeze the final resolved config.
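The validate/_validate split described above can be sketched roughly like this; class names mirror the discussion, but the details (setup, attribute names) are illustrative:

```python
class Config:
    """Instantiation and final resolution are decoupled: validate() can be
    called later, leaving room to inject cross-config dependencies in between."""
    _validated = False

    def validate(self):
        if not self._validated:
            self._validate()
            self._validated = True
        return self

    def _validate(self):  # overridden by subclasses; plays the role of __post_init__
        pass

class DistributedConfig(Config):
    def __init__(self, data_parallel=1):
        self.data_parallel = data_parallel

class BatchConfig(Config):
    def __init__(self, micro_batch_size=1, batch_size=None):
        self.micro_batch_size = micro_batch_size
        self.batch_size = batch_size

    def setup(self, distributed: DistributedConfig):
        # cross-config dependency, set after instantiation, before validation
        self._distributed = distributed

    def _validate(self):
        if self.batch_size is None:  # derived default, still overridable
            self.batch_size = self.micro_batch_size * self._distributed.data_parallel
```

A user who sets batch_size explicitly keeps their value; otherwise the default is derived from the data-parallel size during resolution.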

So Fast-LLM's config system is really standard and follows the same basic principles as dataclasses, pydantic and omegaconf, with one major exception: cross-config dependencies are resolved in instantiated configs, before their final resolutions. There are good reasons for doing so:

  • We know the full structure of the configs we're dealing with, so we have easy access to the field type, defaults and other properties, as well as the class methods, properties, etc. If needed (and we often do), we can manually resolve a child to get access to its final field values and derived fields.
  • We're dealing with objects, not dicts. This means we have full IDE support, with static checks, code completion, etc.

The main drawbacks are that we're calling the method _validate instead of __post_init__ (I couldn't find an easy way around it), and that we need to check whether a child config has been resolved before using it in _validate (i.e. we need to call super().validate() first; there is no such concern outside _validate because of automated validation). These are really simple to explain and address, so basic documentation should be enough.

Right now, the system interleaves validation, mutation, and derivation. It's hard to tell what happened when or why.

I'm not sure about this. I tried to keep things simple and follow the same idea as dataclasses, etc., which cram everything in __post_init__. The order is roughly fixed (as described above) but obviously depends on the developer making the right choices. We could try to improve things a bit, but I'm worried this could lead to excessive structure and complexity that would go against our intention, and in some cases we do need that extra flexibility anyway.

Maybe one initial step would be to convert derived fields to @cached_property or equivalent so they don't clutter the validation step.
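For instance, a derived field moved out of the validation step into a cached property could look like this (a sketch with illustrative names):

```python
from functools import cached_property

class BatchConfig:
    def __init__(self, micro_batch_size: int, num_micro_batches: int):
        self.micro_batch_size = micro_batch_size
        self.num_micro_batches = num_micro_batches

    @cached_property
    def batch_size(self) -> int:
        # computed lazily on first access, cached afterwards,
        # and never part of the serialized config
        return self.micro_batch_size * self.num_micro_batches
```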

Main program, aka runnable

Now we can come back to making a runnable program from the core config system, i.e. turning user inputs into an actual config dict ready to be fed to the config system. I kept it separate in Fast-LLM, and I don't think it's controversial since it's the exact same thing as the omegaconf/hydra split. (Dataclasses and pydantic have no such equivalent, so we're really just comparing to hydra here.) What we have in Fast-LLM is a minimalistic config merging system roughly similar to hydra, with resolution order:

→ Yaml config file
→ Cli updates

That's it (and we usually skip the cli part). The defaults are implicitly first in the resolution order, but are handled in the config instantiation so don't show up here. The yaml part is basically identical to hydra (with a lot fewer features), but the cli part differs a bit:

  • There is no +, ++, - syntax. Everything just sets the value (equivalent to ++), and anything else doesn't make much sense because there is no default yet to override.
  • Setting a nested field (list, dict, child config, etc.) overrides the whole thing, ex. setting batch={"size":4} will override the entire config, so is different from batch.size=4. No idea if it's different from hydra.
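The two update forms can be sketched like this; the helper `apply_cli_update` is illustrative, not the actual Fast-LLM CLI code:

```python
def apply_cli_update(config: dict, key: str, value) -> dict:
    """`a.b.c=value` sets a nested leaf; assigning a dict to a key
    replaces that entire subtree. Returns a new dict (no mutation)."""
    out = dict(config)
    *parents, leaf = key.split(".")
    node = out
    for p in parents:
        node[p] = dict(node.get(p, {}))  # copy-on-write down the path
        node = node[p]
    node[leaf] = value  # a dict value overrides the whole subtree
    return out
```

So `batch.size=4` only changes `size`, while `batch={"size": 4}` discards every other field of `batch`.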

This runnable resolution system seems to be working fine, but I'm open to revisiting if there is a better and/or simpler way.

Pretrained config

The core config and runnable described above are enough to make a generic and robust config system, but don't really have a natural place for dealing with the pretrained model config. We could do it in the runnable part as you're suggesting, but there are good reasons not to. We need to be able to deal with pretrained models and their configs as self-contained objects. So we want pretrained configs to behave just like any normal configs, i.e. like huggingface model configs, so we can play around freely with pretrained configs and models, ex. for testing, debugging, interactive inference, making models on-the-fly, dealing with multiple pretrained models at once (ex. a reference model for distillation), etc.

The need for a self-contained object complicates things quite a bit. What we've been doing so far (and are not planning on changing) is to define a composite config made of a model config and a checkpoint loading config, and construct the combined model config (provided config + loaded config) during the cross-config dependencies step. This is enough to obtain a predictable config resolution order, though it needs to be documented somewhere. The problem this PR attempts to solve is that the resulting resolution order, though predictable, was a bit too complicated and unintuitive because of limitations of the config system:

→ Default
→ Yaml config + cli updates
→ An arbitrary, configurable subset of the pretrained model config, ex. architecture parameters
→ Config resolution

But now that we removed the limitations, we can do the following which makes much more sense:

→ Default
→ The pretrained model config
→ Yaml config + cli updates
→ Config resolution
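As a sketch, the new order is just a sequence of deep merges where later layers win; `deep_merge` and all the values here are illustrative:

```python
def deep_merge(base: dict, update: dict) -> dict:
    """Return a new dict; nested dicts merge, everything else is replaced."""
    out = dict(base)
    for k, v in update.items():
        out[k] = deep_merge(out[k], v) if isinstance(v, dict) and isinstance(out.get(k), dict) else v
    return out

defaults = {"base_model": {"hidden_size": 1024, "window_size": None}}
pretrained = {"base_model": {"hidden_size": 4096}}   # loaded checkpoint config
user = {"base_model": {"window_size": 8192}}         # yaml config + cli updates

# user overrides win over the pretrained config, which wins over defaults
resolved = deep_merge(deep_merge(defaults, pretrained), user)
```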

All this PR is about is the order-reversing part. I agree we need to deal with the arbitrary-subset part, but we'll need to make sure it doesn't break existing work, so let's keep it for the next step.

Base automatically changed from config_updates to main April 14, 2025 20:14
@tscholak tscholak merged commit 1550bd1 into main Apr 17, 2025
2 checks passed
@tscholak tscholak deleted the update_pretrained_config branch April 17, 2025 23:27

Successfully merging this pull request may close these issues.

Make the model config override the pretrained config